From ce16351d16883646a34d246c56a2dddfb561dcf2 Mon Sep 17 00:00:00 2001
From: zhuyuxiao
Date: Tue, 17 May 2022 16:18:22 +0800
Subject: [PATCH] add CPU support for KLDivLoss op

---
 .../cpu/kernel/kl_div_loss_cpu_kernel.cc      | 201 ++++++++++++++++++
 .../cpu/kernel/kl_div_loss_cpu_kernel.h       |  77 +++++++
 .../cuda_ops/loss_with_reduction_impl.cu      |   8 +
 .../gpu/kernel/nn/kl_div_loss_gpu_kernel.cc   |   4 +
 mindspore/core/ops/core_ops.h                 |   1 +
 mindspore/core/ops/kl_div_loss.cc             |  83 ++++++++
 mindspore/core/ops/kl_div_loss.h              |  52 +++++
 mindspore/core/ops/op_name.h                  |   3 +
 .../python/mindspore/ops/operations/nn_ops.py |  39 ++--
 tests/st/ops/cpu/test_kl_div_op.py            | 109 ++++++++++
 tests/ut/python/parallel/test_kldiv_loss.py   |   8 +-
 11 files changed, 563 insertions(+), 22 deletions(-)
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_cpu_kernel.h
 create mode 100644 mindspore/core/ops/kl_div_loss.cc
 create mode 100644 mindspore/core/ops/kl_div_loss.h
 create mode 100644 tests/st/ops/cpu/test_kl_div_op.py

diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_cpu_kernel.cc
new file mode 100644
index 00000000000..ee9b5d4d82a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_cpu_kernel.cc
@@ -0,0 +1,201 @@
+/**
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/cpu/kernel/kl_div_loss_cpu_kernel.h"
+#include <algorithm>
+#include <map>
+#include <string>
+#include <utility>
+#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
+#include "mindspore/core/ops/kl_div_loss.h"
+#include "include/common/thread_pool.h"
+
+namespace mindspore {
+namespace kernel {
+const size_t kKLDivLossInputsNum = 2;
+const size_t kKLDivLossOutputsNum = 1;
+
+bool KLDivLossCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                 const std::vector<KernelTensorPtr> &outputs) {
+  auto kernel_ptr = std::dynamic_pointer_cast<ops::KLDivLoss>(base_operator);
+  if (!kernel_ptr) {
+    MS_LOG(EXCEPTION) << "For KLDivLoss, cast to ops::KLDivLoss failed!";
+    return false;
+  }
+  kernel_name_ = kernel_ptr->name();
+  reductionMode_ = kernel_ptr->get_reduction();
+  batch_size_ = inputs[kIndex0]->GetShapeVector()[kIndex0];
+
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kKLDivLossInputsNum, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kKLDivLossOutputsNum, kernel_name_);
+
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this data type: " << kernel_attr;
+  }
+
+  kernel_func_ = func_list_[index].second;
+  return true;
+}
+
+bool KLDivLossCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                                   const std::vector<AddressPtr> &outputs) {
+  MS_EXCEPTION_IF_NULL(kernel_func_);
+  return kernel_func_(this, inputs, workspace, outputs);
+}
+
+int KLDivLossCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                  const std::vector<KernelTensorPtr> &outputs,
+                                  const std::map<uint32_t, tensor::TensorPtr> &onHost) {
+  int ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, onHost);
+  if (ret != 0) {
+    MS_LOG(WARNING) << kernel_name_ << " reinit failed.";
+    return ret;
+  }
+
+  // Reset the cached element counts: Resize may be called more than once,
+  // and the multiplications below would otherwise accumulate stale values.
+  input_x_shape_size_ = 1;
+  input_target_shape_size_ = 1;
+
+  std::vector<int64_t> input_x_shape = inputs[kIndex0]->GetShapeVector();
+  for (size_t i = 0; i < input_x_shape.size(); ++i) {
+    input_x_shape_size_ *= input_x_shape[i];
+  }
+
+  std::vector<int64_t> input_target_shape = inputs[kIndex1]->GetShapeVector();
+  for (size_t i = 0; i < input_target_shape.size(); ++i) {
+    input_target_shape_size_ *= input_target_shape[i];
+  }
+
+  if (reductionMode_ != ops::kNone) {
+    size_t type_size = GetTypeByte(TypeIdToType(inputs[kIndex0]->GetDtype()));
+    workspace_size_list_.push_back(input_x_shape_size_ * type_size);
+  }
+  return ret;
+}
+
+std::vector<KernelAttr> KLDivLossCpuKernelMod::GetOpSupport() {
+  static std::vector<KernelAttr> support_list;
+  (void)std::transform(
+    func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+    [](const std::pair<KernelAttr, KLDivLossFunc> &pair) { return pair.first; });
+  return support_list;
+}
+
+bool KLDivLossCpuKernelMod::CheckParams() {
+  // For KLDivLoss, the element counts of input 0 (x) and input 1 (target) must match.
+  if (input_target_shape_size_ != input_x_shape_size_) {
+    MS_LOG(ERROR) << kernel_name_ << ": input x shape size = " << input_x_shape_size_
+                  << ", input target shape size = " << input_target_shape_size_ << ". They are not the same.";
+    return false;
+  }
+  return true;
+}
+
+template <typename T>
+bool KLDivLossCpuKernelMod::LaunchNoneReduction(const std::vector<AddressPtr> &inputs,
+                                                const std::vector<AddressPtr> &workspace,
+                                                const std::vector<AddressPtr> &outputs) {
+  auto *input_x = reinterpret_cast<T *>(inputs[kIndex0]->addr);
+  auto *input_target = reinterpret_cast<T *>(inputs[kIndex1]->addr);
+  auto *y = reinterpret_cast<T *>(outputs[kIndex0]->addr);
+
+  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_x(input_x, input_x_shape_size_, 1);
+  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_target(input_target, input_target_shape_size_, 1);
+  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_y(y, input_x_shape_size_, 1);
+
+  array_y = array_target * (Eigen::log(array_target) - array_x);
+
+  // Zero out NaN terms produced by non-positive targets (log of a non-positive number).
+  auto task = [&](size_t start, size_t end) {
+    for (size_t i = start; i < end; ++i) {
+      if (std::isnan(static_cast<double>(array_y[i]))) {
+        array_y[i] = static_cast<T>(0);
+      }
+    }
+  };
+  ParallelLaunchAutoSearch(task, input_x_shape_size_, this, &parallel_search_info_);
+
+  return true;
+}
+
+template <typename T>
+bool KLDivLossCpuKernelMod::LaunchOther(const std::vector<AddressPtr> &inputs,
+                                        const std::vector<AddressPtr> &workspace,
+                                        const std::vector<AddressPtr> &outputs) {
+  auto *input_x = reinterpret_cast<T *>(inputs[kIndex0]->addr);
+  auto *input_target = reinterpret_cast<T *>(inputs[kIndex1]->addr);
+  auto *tmp_result = reinterpret_cast<T *>(workspace[kIndex0]->addr);
+  auto *y = reinterpret_cast<T *>(outputs[kIndex0]->addr);
+
+  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_x(input_x, input_x_shape_size_, 1);
+  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_target(input_target, input_target_shape_size_, 1);
+  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_tmp(tmp_result, input_x_shape_size_, 1);
+
+  array_tmp = array_target * (Eigen::log(array_target) - array_x);
+
+  auto task = [&](size_t start, size_t end) {
+    for (size_t i = start; i < end; ++i) {
+      if (std::isnan(static_cast<double>(array_tmp[i]))) {
+        array_tmp[i] = static_cast<T>(0);
+      }
+    }
+  };
+  ParallelLaunchAutoSearch(task, input_x_shape_size_, this, &parallel_search_info_);
+
+  if (reductionMode_ == ops::kSum) {
+    y[kIndex0] = array_tmp.sum();
+    return true;
+  }
+
+  if (reductionMode_ == ops::kMean) {
+    y[kIndex0] = array_tmp.mean();
+    return true;
+  }
+
+  if (reductionMode_ == ops::kBatchMean) {
+    y[kIndex0] = array_tmp.sum() / static_cast<T>(batch_size_);
+    return true;
+  }
+
+  return false;
+}
+
+template <typename T>
+bool KLDivLossCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                         const std::vector<AddressPtr> &workspace,
+                                         const std::vector<AddressPtr> &outputs) {
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kKLDivLossInputsNum, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kKLDivLossOutputsNum, kernel_name_);
+  if (!KLDivLossCpuKernelMod::CheckParams()) {
+    MS_LOG(EXCEPTION) << kernel_name_ << ": check param failed.";
+  }
+
+  if (reductionMode_ == ops::kNone) {
+    return LaunchNoneReduction<T>(inputs, workspace, outputs);
+  }
+
+  return LaunchOther<T>(inputs, workspace, outputs);
+}
+
+std::vector<std::pair<KernelAttr, KLDivLossCpuKernelMod::KLDivLossFunc>> KLDivLossCpuKernelMod::func_list_ = {
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+   &KLDivLossCpuKernelMod::LaunchKernel<Eigen::half>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+   &KLDivLossCpuKernelMod::LaunchKernel<float>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+   &KLDivLossCpuKernelMod::LaunchKernel<double>}};
+
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, KLDivLoss, KLDivLossCpuKernelMod);
+}  // namespace kernel
+}  // namespace mindspore
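Note: the kernel above evaluates the elementwise term target * (log(target) - x) and zeroes any NaN before the optional reduction. A minimal NumPy reference of that contract (illustrative sketch only; kl_div_loss_ref is a hypothetical helper inferred from the kernel code, not part of this patch):

    import numpy as np

    def kl_div_loss_ref(x, target, reduction="mean"):
        # Elementwise KL term; np.log yields NaN for target < 0 and -inf at 0.
        loss = target * (np.log(target) - x)
        loss = np.where(np.isnan(loss), 0.0, loss)  # the kernel zeroes NaN terms
        if reduction == "none":
            return loss
        if reduction == "sum":
            return loss.sum()
        if reduction == "batchmean":
            return loss.sum() / loss.shape[0]       # divide by batch size only
        return loss.mean()                          # "mean" divides by element count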
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_cpu_kernel.h
new file mode 100644
index 00000000000..20c9b5bbe0f
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_cpu_kernel.h
@@ -0,0 +1,77 @@
+/**
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_KL_DIV_LOSS_CPU_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_KL_DIV_LOSS_CPU_KERNEL_H
+
+#include <functional>
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class KLDivLossCpuKernelMod : public NativeCpuKernelMod {
+ public:
+  KLDivLossCpuKernelMod() = default;
+  ~KLDivLossCpuKernelMod() override = default;
+
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+             const std::vector<KernelTensorPtr> &outputs,
+             const std::map<uint32_t, tensor::TensorPtr> &onHost = std::map<uint32_t, tensor::TensorPtr>()) override;
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;
+
+ private:
+  template <typename T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                    const std::vector<AddressPtr> &outputs);
+
+  template <typename T>
+  bool LaunchNoneReduction(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                           const std::vector<AddressPtr> &outputs);
+
+  template <typename T>
+  bool LaunchOther(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                   const std::vector<AddressPtr> &outputs);
+
+  bool CheckParams();
+
+  using KLDivLossFunc = std::function<bool(KLDivLossCpuKernelMod *, const std::vector<AddressPtr> &,
+                                           const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
+
+  static std::vector<std::pair<KernelAttr, KLDivLossFunc>> func_list_;
+  KLDivLossFunc kernel_func_;
+  std::string reductionMode_;
+  int64_t batch_size_{0};
+  size_t input_x_shape_size_{1};
+  size_t input_target_shape_size_{1};
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_KL_DIV_LOSS_CPU_KERNEL_H
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cu
index 0d2d2c76166..4f69a780ea2 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cu
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cu
@@ -390,6 +390,14 @@ template CUDA_LIB_EXPORT void KLDivLossGrad<float>(const int &input_size, const
                                                    const float *input_x, const float *input_y, const float *dloss,
                                                    float *dx, float *dy, cudaStream_t stream);
+template CUDA_LIB_EXPORT void KLDivLoss<double>(const int &input_size, const ReductionMode &reduction,
+                                                const double *input_x, const double *input_y, double *loss,
+                                                double *tmp_loss, cudaStream_t stream);
+
+template CUDA_LIB_EXPORT void KLDivLossGrad<double>(const int &input_size, const ReductionMode &reduction,
+                                                    const double *input_x, const double *input_y, const double *dloss,
+                                                    double *dx, double *dy, cudaStream_t stream);
+
 
 template CUDA_LIB_EXPORT void BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction,
                                                      const float *input_x, const float *input_y, const float *weight,
                                                      float *loss, float *tmp_loss,
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/kl_div_loss_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/kl_div_loss_gpu_kernel.cc
index 4ec5133e3aa..b949418b74d 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/kl_div_loss_gpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/kl_div_loss_gpu_kernel.cc
@@ -26,5 +26,9 @@ MS_REG_GPU_KERNEL_ONE(
   KLDivLoss,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   KLDivLossGpuKernelMod, half)
+MS_REG_GPU_KERNEL_ONE(
+  KLDivLoss,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  KLDivLossGpuKernelMod, double)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h
index 96dac6704d3..62d7c4c1089 100644
--- a/mindspore/core/ops/core_ops.h
+++ b/mindspore/core/ops/core_ops.h
@@ -830,6 +830,7 @@ GVAR_DEF(PrimitivePtr, kPrimZeta, std::make_shared<Primitive>("Zeta"));
 GVAR_DEF(PrimitivePtr, kPrimIgamma, std::make_shared<Primitive>("Igamma"));
 GVAR_DEF(PrimitivePtr, kPrimIgammac, std::make_shared<Primitive>("Igammac"));
 GVAR_DEF(PrimitivePtr, kPrimIgammaGradA, std::make_shared<Primitive>("IgammaGradA"));
+GVAR_DEF(PrimitivePtr, kPrimKLDivLoss, std::make_shared<Primitive>("KLDivLoss"));
 
 // linalg
 GVAR_DEF(PrimitivePtr, kPrimSvd, std::make_shared<Primitive>("Svd"));
diff --git a/mindspore/core/ops/kl_div_loss.cc b/mindspore/core/ops/kl_div_loss.cc
new file mode 100644
index 00000000000..d4d6e949085
--- /dev/null
+++ b/mindspore/core/ops/kl_div_loss.cc
@@ -0,0 +1,83 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/kl_div_loss.h"
+#include <map>
+#include <set>
+#include "mindapi/ir/type.h"
+#include "utils/check_convert_utils.h"
+#include "ops/op_utils.h"
+#include "mindapi/src/helper.h"
+
+namespace mindspore {
+namespace ops {
+void KLDivLoss::Init(const std::string &reduction) { set_reduction(reduction); }
+
+void KLDivLoss::set_reduction(const std::string &reduction) {
+  (void)this->AddAttr(kReduction, api::MakeValue(reduction));
+}
+
+std::string KLDivLoss::get_reduction() const { return GetValue<std::string>(GetAttr(ops::kReduction)); }
+
+abstract::ShapePtr KLDivLossInferShape(const PrimitivePtr &primitive,
+                                       const std::vector<AbstractBasePtr> &input_args) {
+  auto op_name = primitive->name();
+  auto input_x_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
+  auto input_x_shape = input_x_map[kShape];
+  auto input_target_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
+  auto input_target_shape = input_target_map[kShape];
+  CheckAndConvertUtils::Check("x shape", input_x_shape, kEqual, input_target_shape, op_name, ValueError);
+
+  auto reduction = GetValue<std::string>(primitive->GetAttr(kReduction));
+  if (reduction == kNone) {
+    return std::make_shared<abstract::Shape>(input_x_shape);
+  }
+
+  if (reduction == kBatchMean && input_x_shape.size() == 0) {
+    MS_LOG(EXCEPTION) << "For " << op_name << ", reduction 'batchmean' requires x to have at least one dimension, "
+                      << "but got a scalar (x shape = []).";
+  }
+
+  std::vector<int64_t> y_shape;
+  return std::make_shared<abstract::Shape>(y_shape);
+}
+
+TypePtr KLDivLossInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  auto op_name = prim->name();
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
+  auto input_x_type = input_args[kInputIndex0]->BuildType();
+  auto input_target_type = input_args[kInputIndex1]->BuildType();
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, op_name);
+
+  std::map<std::string, TypePtr> types;
+  (void)types.emplace("x", input_x_type);
+  (void)types.emplace("target", input_target_type);
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
+  return input_x_type;
+}
+
+MIND_API_OPERATOR_IMPL(KLDivLoss, BaseOperator);
+AbstractBasePtr KLDivLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  const int64_t kInputsNum = 2;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
+  auto infer_shape = KLDivLossInferShape(primitive, input_args);
+  auto infer_type = KLDivLossInferType(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(KLDivLoss, prim::kPrimKLDivLoss, KLDivLossInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
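The shape rule implemented by KLDivLossInferShape, summarized as a small Python sketch (illustrative only; the helper name is hypothetical):

    def kl_div_loss_infer_shape(x_shape, target_shape, reduction):
        # x and target must agree elementwise.
        if x_shape != target_shape:
            raise ValueError("x and target must have the same shape")
        if reduction == "none":
            return x_shape  # elementwise loss keeps the input shape
        if reduction == "batchmean" and len(x_shape) == 0:
            raise ValueError("batchmean requires at least one (batch) dimension")
        return []  # mean/sum/batchmean reduce to a scalar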
diff --git a/mindspore/core/ops/kl_div_loss.h b/mindspore/core/ops/kl_div_loss.h
new file mode 100644
index 00000000000..dc7eb558bef
--- /dev/null
+++ b/mindspore/core/ops/kl_div_loss.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_KL_DIV_LOSS_H
+#define MINDSPORE_CORE_OPS_KL_DIV_LOSS_H
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "ops/base_operator.h"
+#include "ops/op_name.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameKLDivLoss = "KLDivLoss";
+/// \brief Computes the Kullback-Leibler divergence between the logits and the labels.
+/// Refer to Python API @ref mindspore.ops.KLDivLoss for more details.
+class MIND_API KLDivLoss : public BaseOperator {
+ public:
+  MIND_API_BASE_MEMBER(KLDivLoss);
+  /// \brief Constructor.
+  KLDivLoss() : BaseOperator(kNameKLDivLoss) { InitIOName({"x", "target"}, {"y"}); }
+  explicit KLDivLoss(const std::string &k_name) : BaseOperator(k_name) { InitIOName({"x", "target"}, {"y"}); }
+  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.KLDivLoss for the inputs.
+  void Init(const std::string &reduction = kMean);
+  /// \brief Set reduction.
+  void set_reduction(const std::string &reduction);
+  /// \brief Get reduction.
+  ///
+  /// \return reduction.
+  std::string get_reduction() const;
+};
+
+abstract::AbstractBasePtr KLDivLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const std::vector<abstract::AbstractBasePtr> &input_args);
+}  // namespace ops
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CORE_OPS_KL_DIV_LOSS_H
diff --git a/mindspore/core/ops/op_name.h b/mindspore/core/ops/op_name.h
index 244a08ec2b3..90f8f4078d1 100644
--- a/mindspore/core/ops/op_name.h
+++ b/mindspore/core/ops/op_name.h
@@ -264,7 +264,10 @@ constexpr auto kIsOriginalPadMode = "is_original_pad_mode";
 constexpr auto kOriginalOpName = "original_op_name";
 constexpr auto kSymmetric = "symmetric";
 constexpr auto kDstType = "dst_type";
+constexpr auto kNone = "none";
 constexpr auto kMean = "mean";
+constexpr auto kBatchMean = "batchmean";
+constexpr auto kSum = "sum";
 constexpr auto kIndices = "indices";
 constexpr auto kBegin = "begin";
 constexpr auto kSrcFormat = "src_format";
diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py
index 9b6b762f573..5e34cabd554 100644
--- a/mindspore/python/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/python/mindspore/ops/operations/nn_ops.py
@@ -5126,7 +5126,7 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
         return var_dtype, accum_dtype
 
 
-class KLDivLoss(PrimitiveWithInfer):
+class KLDivLoss(Primitive):
     r"""
     Computes the Kullback-Leibler divergence between the logits and the labels.
 
@@ -5134,27 +5134,28 @@ class KLDivLoss(PrimitiveWithInfer):
 
     .. math::
         L = \{l_1,\dots,l_N\}^\top, \quad
-        l_n = y_n \cdot (\log y_n - x_n)
+        l_n = \text{target}_n \cdot (\log \text{target}_n - x_n)
 
     Then,
 
     .. math::
-        \ell(x, y) = \begin{cases}
+        \ell(x, \text{target}) = \begin{cases}
         L, & \text{if reduction} = \text{'none';}\\
         \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+        \operatorname{sum}(L) / x.\operatorname{shape}[0], & \text{if reduction} = \text{'batchmean';}\\
         \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
         \end{cases}
 
     where :math:`x` represents `logits`.
-    :math:`y` represents `labels`.
-    :math:`\ell(x, y)` represents `output`.
+    :math:`\text{target}` represents `labels`.
+    :math:`\ell(x, \text{target})` represents `output`.
 
     Args:
         reduction (str): Specifies the reduction to be applied to the output.
-            Its value must be one of 'none', 'mean' or 'sum'. Default: 'mean'.
+            Its value must be one of 'none', 'mean', 'batchmean' or 'sum'. Default: 'mean'.
+            In 'batchmean' mode, the summed loss is divided by the batch size (the first dimension of `logits`).
 
     Inputs:
-        - **logits** (Tensor) - The input Tensor. The data type must be float32.
+        - **logits** (Tensor) - The input Tensor. The data type must be float16, float32 or float64.
         - **labels** (Tensor) - The label Tensor which has the same shape and data type as `logits`.
 
     Outputs:
@@ -5167,7 +5168,7 @@ class KLDivLoss(PrimitiveWithInfer):
-        TypeError: If dtype of `logits` or `labels` is not float16, float32 or float64.
+        TypeError: If dtype of `logits` or `labels` is not float16, float32 or float64.
 
     Supported Platforms:
-        ``GPU``
+        ``CPU`` ``GPU``
 
     Examples:
         >>> class Net(nn.Cell):
@@ -5189,21 +5190,17 @@ class KLDivLoss(PrimitiveWithInfer):
 
     @prim_attr_register
     def __init__(self, reduction='mean'):
         """Initialize KLDivLoss."""
-        self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
-
-    def infer_shape(self, x_shape, y_shape):
-        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
-        if self.reduction in ('mean', 'sum'):
-            shape = []
-        else:
-            shape = x_shape
-        return shape
-
-    def infer_dtype(self, x_type, y_type):
-        args = {'x': x_type, 'y': y_type}
-        valid_dtypes = (mstype.float16, mstype.float32)
-        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
-        return x_type
+        device_target = context.get_context("device_target")
+        if device_target == "CPU":
+            support_mode = ['none', 'mean', 'batchmean', 'sum']
+        elif device_target == "GPU":
+            support_mode = ['none', 'mean', 'sum']
+        elif device_target == "Ascend":
+            raise ValueError(f"'{self.name}' does not support the Ascend platform currently.")
+        else:
+            raise ValueError(f"'{self.name}' got an unknown device target: '{device_target}'.")
+
+        self.reduction = validator.check_string(reduction, support_mode, 'reduction', self.name)
 
 
 class BinaryCrossEntropy(Primitive):
diff --git a/tests/st/ops/cpu/test_kl_div_op.py b/tests/st/ops/cpu/test_kl_div_op.py
new file mode 100644
index 00000000000..36b6f7e5c62
--- /dev/null
+++ b/tests/st/ops/cpu/test_kl_div_op.py
@@ -0,0 +1,109 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class Net(nn.Cell):
+    def __init__(self, reduction="none"):
+        super(Net, self).__init__()
+        self.kl_div_loss = P.KLDivLoss(reduction)
+
+    def construct(self, x, y):
+        return self.kl_div_loss(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
+def test_mode_none_and_dtype_with_static_input(dtype):
+    """
+    Feature: test none mode with different input dtype.
+    Description: input with negative elements.
+    Expectation: success.
+    """
+    prediction = Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype))
+    target = Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype))
+    net = Net("none")
+    loss = net(prediction, target)
+    expect = np.array([[0, 0.35667494], [0.69314718, 0]]).astype(dtype)
+    assert np.allclose(loss.asnumpy(), expect)
+ """ + np.random.seed(42) + prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype)) + target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype)) + net = Net("none") + loss = net(Tensor(prediction), Tensor(target)) + expect = np.array([[0, 0.35667494], [0.69314718, 0]]).astype(dtype) + assert np.allclose(loss.asnumpy(), expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) +def test_mode_mean_and_dtype_with_static_input(dtype): + """ + Feature: test mean mode with different input dtype. + Description: input with negative elements. + Expectation: success. + """ + np.random.seed(42) + prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype)) + target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype)) + net = Net("mean") + loss = net(Tensor(prediction), Tensor(target)) + expect = np.array([0.26245553]).astype(dtype) + assert np.allclose(loss.asnumpy(), expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) +def test_mode_sum_and_dtype_with_static_input(dtype): + """ + Feature: test sum mode with different input dtype. + Description: input with negative elements. + Expectation: success. + """ + np.random.seed(42) + prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype)) + target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype)) + net = Net("sum") + loss = net(Tensor(prediction), Tensor(target)) + expect = np.array([1.04982212]).astype(dtype) + assert np.allclose(loss.asnumpy(), expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) +def test_mode_batchmean_and_dtype_with_static_input(dtype): + """ + Feature: test batchmean mode with different input dtype. + Description: input with negative elements. + Expectation: success. + """ + np.random.seed(42) + prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype)) + target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype)) + net = Net("batchmean") + loss = net(Tensor(prediction), Tensor(target)) + expect = np.array([0.52491106]).astype(dtype) + assert np.allclose(loss.asnumpy(), expect) diff --git a/tests/ut/python/parallel/test_kldiv_loss.py b/tests/ut/python/parallel/test_kldiv_loss.py index 42ef7b08bcb..4ecc53f38ee 100644 --- a/tests/ut/python/parallel/test_kldiv_loss.py +++ b/tests/ut/python/parallel/test_kldiv_loss.py @@ -13,7 +13,6 @@ # limitations under the License. 
 import numpy as np
-
 import mindspore.common.dtype as mstype
 from mindspore import Tensor, context
 from mindspore.nn import Cell
@@ -41,6 +40,7 @@ def test_kldiv_loss_mean_auto_parallel():
     Description: auto parallel, reduction is 'mean'
     Expectation: compile success
     """
+    context.set_context(device_target="GPU")
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
     reduction = 'mean'
     net = Net(reduction)
@@ -53,6 +53,7 @@ def test_kldiv_loss_none_auto_parallel():
     Description: auto parallel, reduction is 'none'
     Expectation: compile success
     """
+    context.set_context(device_target="GPU")
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
     reduction = 'none'
     net = Net(reduction)
@@ -65,6 +66,7 @@ def test_kldiv_loss_sum_auto_parallel():
     Description: auto parallel, reduction is 'sum'
     Expectation: compile success
     """
+    context.set_context(device_target="GPU")
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
     reduction = 'sum'
     net = Net(reduction)
@@ -77,6 +79,7 @@ def test_kldiv_loss_mean_data_parallel():
     Description: data parallel, reduction is 'mean'
     Expectation: compile success
     """
+    context.set_context(device_target="GPU")
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
     reduction = 'mean'
     net = Net(reduction)
@@ -92,6 +95,7 @@ def test_kldiv_loss_none_data_parallel():
     Description: data parallel, reduction is 'none'
     Expectation: compile success
     """
+    context.set_context(device_target="GPU")
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
     reduction = 'none'
     net = Net(reduction)
@@ -104,6 +108,7 @@ def test_kldiv_loss_none_model_parallel():
     Description: model parallel, reduction is 'none'
     Expectation: compile success
     """
+    context.set_context(device_target="GPU")
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
     reduction = 'none'
     strategy = ((2, 2), (2, 2))
@@ -117,6 +122,7 @@ def test_kldiv_loss_mean_model_parallel():
     Description: model parallel, reduction is 'mean'
     Expectation: compile success
     """
+    context.set_context(device_target="GPU")
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
     reduction = 'mean'
     strategy = ((4, 2), (4, 2))
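For reference, a minimal end-to-end usage sketch of the op on the new CPU backend, including the new 'batchmean' mode (illustrative; assumes a build containing this patch, inputs adapted from the new ST tests):

    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

    class Net(nn.Cell):
        def __init__(self, reduction='batchmean'):
            super(Net, self).__init__()
            self.kldiv_loss = P.KLDivLoss(reduction)

        def construct(self, logits, labels):
            return self.kldiv_loss(logits, labels)

    logits = Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(np.float32))
    labels = Tensor(np.array([[0.0, 1.0], [1.0, 0.0]]).astype(np.float32))
    out = Net()(logits, labels)
    print(out)  # ~0.52491: summed elementwise loss divided by the batch size (2)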