forked from mindspore-Ecosystem/mindspore
!34511 Add CPU support for KLDivLoss op
Merge pull request !34511 from zhuyuxiao/I51VMV
This commit is contained in: commit 4f0b1dc60e
@@ -0,0 +1,201 @@
/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/kl_div_loss_cpu_kernel.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <map>
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/core/ops/kl_div_loss.h"
#include "include/common/thread_pool.h"

namespace mindspore {
namespace kernel {
const size_t kMyAddInputsNum = 2;
const size_t kMyAddOutputsNum = 1;

bool KLDivLossCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                 const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::KLDivLoss>(base_operator);
  if (!kernel_ptr) {
    MS_LOG(EXCEPTION) << "cast KLDivLoss ops failed!";
    return false;
  }
  kernel_name_ = kernel_ptr->name();
  reductionMode_ = kernel_ptr->get_reduction();
  batch_size_ = inputs[kIndex0]->GetShapeVector()[kIndex0];

  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMyAddInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMyAddOutputsNum, kernel_name_);

  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this data type: " << kernel_attr;
  }

  kernel_func_ = func_list_[index].second;
  return true;
}

bool KLDivLossCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                   const std::vector<AddressPtr> &outputs) {
  MS_EXCEPTION_IF_NULL(kernel_func_);
  return kernel_func_(this, inputs, workspace, outputs);
}

int KLDivLossCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                  const std::vector<KernelTensorPtr> &outputs,
                                  const std::map<uint32_t, tensor::TensorPtr> &onHost) {
  int ret = 0;
  ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, onHost);
  if (ret != 0) {
    MS_LOG(WARNING) << kernel_name_ << " reinit failed.";
    return ret;
  }

  std::vector<int64_t> input_x_shape = inputs[kIndex0]->GetShapeVector();
  for (size_t i = 0; i < input_x_shape.size(); ++i) {
    input_x_shape_size_ *= input_x_shape[i];
  }

  std::vector<int64_t> input_target_shape = inputs[kIndex1]->GetShapeVector();
  for (size_t i = 0; i < input_target_shape.size(); ++i) {
    input_target_shape_size_ *= input_target_shape[i];
  }

  if (reductionMode_ != ops::kNone) {
    size_t type_size = GetTypeByte(TypeIdToType(inputs[kIndex0]->GetDtype()));
    workspace_size_list_.push_back(input_x_shape_size_ * type_size);
  }
  return ret;
}

std::vector<KernelAttr> KLDivLossCpuKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> support_list;
  (void)std::transform(
    func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
    [](const std::pair<KernelAttr, KLDivLossCpuKernelMod::KLDivLossFunc> &pair) { return pair.first; });
  return support_list;
}

bool KLDivLossCpuKernelMod::CheckParams() {
  // for kl div, shape size of input 0 and input 1 must be the same
  if (input_target_shape_size_ != input_x_shape_size_) {
    MS_LOG(ERROR) << kernel_name_ << ": input x shape size = " << input_x_shape_size_
                  << ", input target shape size = " << input_target_shape_size_ << ". They are not the same.";
    return false;
  }
  return true;
}

template <typename T>
bool KLDivLossCpuKernelMod::LaunchNoneReduction(const std::vector<AddressPtr> &inputs,
                                                const std::vector<AddressPtr> &workspace,
                                                const std::vector<AddressPtr> &outputs) {
  auto *input_x = reinterpret_cast<T *>(inputs[kIndex0]->addr);
  auto *input_target = reinterpret_cast<T *>(inputs[kIndex1]->addr);
  auto *y = reinterpret_cast<T *>(outputs[kIndex0]->addr);

  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_x(input_x, input_x_shape_size_, 1);
  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_target(input_target, input_target_shape_size_, 1);
  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_y(y, input_x_shape_size_, 1);

  array_y = array_target * (Eigen::log(array_target) - array_x);

  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      if (std::isnan(static_cast<float>(array_y[i]))) {
        array_y[i] = static_cast<T>(0);
      }
    }
  };
  ParallelLaunchAutoSearch(task, input_x_shape_size_, this, &parallel_search_info_);

  return true;
}

template <typename T>
bool KLDivLossCpuKernelMod::LaunchOther(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                        const std::vector<AddressPtr> &outputs) {
  auto *input_x = reinterpret_cast<T *>(inputs[kIndex0]->addr);
  auto *input_target = reinterpret_cast<T *>(inputs[kIndex1]->addr);
  auto *tmp_result = reinterpret_cast<T *>(workspace[kIndex0]->addr);
  auto *y = reinterpret_cast<T *>(outputs[kIndex0]->addr);

  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_x(input_x, input_x_shape_size_, 1);
  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_target(input_target, input_target_shape_size_, 1);
  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_tmp(tmp_result, input_x_shape_size_, 1);

  array_tmp = array_target * (Eigen::log(array_target) - array_x);

  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      if (std::isnan(static_cast<float>(array_tmp[i]))) {
        array_tmp[i] = static_cast<T>(0);
      }
    }
  };
  ParallelLaunchAutoSearch(task, input_x_shape_size_, this, &parallel_search_info_);

  if (reductionMode_ == ops::kSum) {
    y[kIndex0] = array_tmp.sum();
    return true;
  }

  if (reductionMode_ == ops::kMean) {
    y[kIndex0] = array_tmp.mean();
    return true;
  }

  if (reductionMode_ == ops::kBatchMean) {
    y[kIndex0] = array_tmp.sum() / static_cast<T>(batch_size_);
    return true;
  }

  return false;
}

template <typename T>
bool KLDivLossCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                         const std::vector<AddressPtr> &workspace,
                                         const std::vector<AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMyAddInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMyAddOutputsNum, kernel_name_);
  if (!KLDivLossCpuKernelMod::CheckParams()) {
    MS_LOG(EXCEPTION) << kernel_name_ << ": check param failed.";
  }

  if (reductionMode_ == ops::kNone) {
    return LaunchNoneReduction<T>(inputs, workspace, outputs);
  }

  return LaunchOther<T>(inputs, workspace, outputs);
}

std::vector<std::pair<KernelAttr, KLDivLossCpuKernelMod::KLDivLossFunc>> KLDivLossCpuKernelMod::func_list_ = {
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   &KLDivLossCpuKernelMod::LaunchKernel<Eigen::half>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   &KLDivLossCpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   &KLDivLossCpuKernelMod::LaunchKernel<double>}};

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, KLDivLoss, KLDivLossCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
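For reference, the math carried out by LaunchNoneReduction and LaunchOther above can be modelled with a few lines of NumPy. This is an illustrative sketch only; the function name and the array-based interface are assumptions for the example and are not part of the MindSpore API:

import numpy as np

def kl_div_loss_reference(x, target, reduction="mean"):
    # Element-wise term: target * (log(target) - x); NaN entries
    # (for example from log of a non-positive target) are zeroed,
    # mirroring the NaN handling in the kernel above.
    with np.errstate(invalid="ignore", divide="ignore"):
        loss = target * (np.log(target) - x)
    loss = np.where(np.isnan(loss), 0.0, loss)
    if reduction == "none":
        return loss
    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.mean()
    if reduction == "batchmean":
        # 'batchmean' divides the sum by the first (batch) dimension only.
        return loss.sum() / loss.shape[0]
    raise ValueError(f"unsupported reduction: {reduction}")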
@@ -0,0 +1,77 @@
/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_KL_DIV_LOSS_CPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_KL_DIV_LOSS_CPU_KERNEL_H

#include <vector>
#include <memory>
#include <utility>
#include <map>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class KLDivLossCpuKernelMod : public NativeCpuKernelMod {
 public:
  KLDivLossCpuKernelMod() {}
  ~KLDivLossCpuKernelMod() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs,
             const std::map<uint32_t, tensor::TensorPtr> &onHost = std::map<uint32_t, tensor::TensorPtr>()) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);

  template <typename T>
  bool LaunchNoneReduction(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                           const std::vector<AddressPtr> &outputs);

  template <typename T>
  bool LaunchOther(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                   const std::vector<AddressPtr> &outputs);

  bool CheckParams();

  using KLDivLossFunc = std::function<bool(KLDivLossCpuKernelMod *, const std::vector<AddressPtr> &,
                                           const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;

 private:
  static std::vector<std::pair<KernelAttr, KLDivLossFunc>> func_list_;
  KLDivLossFunc kernel_func_;
  std::string reductionMode_;
  int64_t batch_size_;
  size_t input_x_shape_size_{1};
  size_t input_target_shape_size_{1};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_KL_DIV_LOSS_CPU_KERNEL_H
@@ -390,6 +390,14 @@ template CUDA_LIB_EXPORT void KLDivLossGrad<float>(const int &input_size, const
                                                   const float *input_x, const float *input_y, const float *dloss,
                                                   float *dx, float *dy, cudaStream_t stream);

template CUDA_LIB_EXPORT void KLDivLoss<double>(const int &input_size, const ReductionMode &reduction,
                                                const double *input_x, const double *input_y, double *loss,
                                                double *tmp_loss, cudaStream_t stream);

template CUDA_LIB_EXPORT void KLDivLossGrad<double>(const int &input_size, const ReductionMode &reduction,
                                                    const double *input_x, const double *input_y, const double *dloss,
                                                    double *dx, double *dy, cudaStream_t stream);

template CUDA_LIB_EXPORT void BinaryCrossEntropyLoss<float>(const int &input_size, const ReductionMode &reduction,
                                                            const float *input_x, const float *input_y,
                                                            const float *weight, float *loss, float *tmp_loss,
@@ -26,5 +26,9 @@ MS_REG_GPU_KERNEL_ONE(
  KLDivLoss,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  KLDivLossGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
  KLDivLoss,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
  KLDivLossGpuKernelMod, double)
}  // namespace kernel
}  // namespace mindspore
@@ -847,6 +847,7 @@ GVAR_DEF(PrimitivePtr, kPrimZeta, std::make_shared<Primitive>("Zeta"));
GVAR_DEF(PrimitivePtr, kPrimIgamma, std::make_shared<Primitive>("Igamma"));
GVAR_DEF(PrimitivePtr, kPrimIgammac, std::make_shared<Primitive>("Igammac"));
GVAR_DEF(PrimitivePtr, kPrimIgammaGradA, std::make_shared<Primitive>("IgammaGradA"));
GVAR_DEF(PrimitivePtr, kPrimKLDivLoss, std::make_shared<Primitive>("KLDivLoss"));

// linalg
GVAR_DEF(PrimitivePtr, kPrimSvd, std::make_shared<Primitive>("Svd"));
@@ -0,0 +1,83 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/kl_div_loss.h"
#include <map>
#include <set>
#include "mindapi/ir/type.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
void KLDivLoss::Init(const std::string &reduction) { set_reduction(reduction); }

void KLDivLoss::set_reduction(const std::string &reduction) {
  (void)this->AddAttr(kReduction, api::MakeValue(reduction));
}

std::string KLDivLoss::get_reduction() const { return GetValue<std::string>(GetAttr(ops::kReduction)); }

abstract::ShapePtr KLDivLossInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = primitive->name();
  auto input_x_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
  auto input_x_shape = input_x_map[kShape];
  auto input_target_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
  auto input_target_shape = input_target_map[kShape];
  CheckAndConvertUtils::Check("x shape", input_x_shape, kEqual, input_target_shape, op_name, ValueError);

  auto reduction = GetValue<std::string>(primitive->GetAttr(kReduction));
  if (reduction == kNone) {
    return std::make_shared<abstract::Shape>(input_x_shape);
  }

  if (reduction == kBatchMean && input_x_shape.size() == 0) {
    MS_LOG(EXCEPTION) << "For " << op_name << ", can not do batchmean with x shape = []";
  }

  std::vector<std::int64_t> y_shape;
  y_shape.resize(0);
  return std::make_shared<abstract::Shape>(y_shape);
}

TypePtr KLDivLossInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = prim->name();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
  auto input_x_type = input_args[kInputIndex0]->BuildType();
  auto input_target_type = input_args[kInputIndex1]->BuildType();
  CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, op_name);

  std::map<std::string, TypePtr> types;
  types.emplace("x", input_x_type);
  types.emplace("target", input_target_type);
  CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
  return input_x_type;
}

MIND_API_OPERATOR_IMPL(KLDivLoss, BaseOperator);
AbstractBasePtr KLDivLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t kInputsNum = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
  auto infer_shape = KLDivLossInferShape(primitive, input_args);
  auto infer_type = KLDivLossInferType(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(KLDivLoss, prim::kPrimKLDivLoss, KLDivLossInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
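The shape rule encoded in KLDivLossInferShape above is small enough to restate as a sketch. The helper below is illustrative only and assumes shapes are given as plain Python lists; it is not part of the diff:

def kl_div_loss_output_shape(x_shape, target_shape, reduction):
    # x and target must agree; 'none' keeps the element-wise shape,
    # every other reduction produces a scalar (empty shape).
    if x_shape != target_shape:
        raise ValueError("x and target must have the same shape")
    if reduction == "none":
        return x_shape
    if reduction == "batchmean" and len(x_shape) == 0:
        raise ValueError("cannot take batchmean of a scalar input")
    return []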
@@ -0,0 +1,52 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_KL_DIV_LOSS_H
#define MINDSPORE_CORE_OPS_KL_DIV_LOSS_H

#include <string>
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "ops/op_name.h"

namespace mindspore {
namespace ops {
constexpr auto kNameKLDivLoss = "KLDivLoss";
/// \brief Computes the Kullback-Leibler divergence between the logits and the labels.
/// Refer to Python API @ref mindspore.ops.KLDivLoss for more details.
class MIND_API KLDivLoss : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(KLDivLoss);
  /// \brief Constructor.
  KLDivLoss() : BaseOperator(kNameKLDivLoss) { InitIOName({"x", "target"}, {"y"}); }
  explicit KLDivLoss(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x", "target"}, {"y"}); }
  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.KLDivLoss for the inputs.
  void Init(const std::string &reduction = kMean);
  /// \brief Set reduction.
  void set_reduction(const std::string &reduction);
  /// \brief Get reduction.
  ///
  /// \return reduction.
  std::string get_reduction() const;
};

abstract::AbstractBasePtr KLDivLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const std::vector<abstract::AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_KL_DIV_LOSS_H
@@ -265,7 +265,10 @@ constexpr auto kIsOriginalPadMode = "is_original_pad_mode";
constexpr auto kOriginalOpName = "original_op_name";
constexpr auto kSymmetric = "symmetric";
constexpr auto kDstType = "dst_type";
constexpr auto kNone = "none";
constexpr auto kMean = "mean";
constexpr auto kBatchMean = "batchmean";
constexpr auto kSum = "sum";
constexpr auto kIndices = "indices";
constexpr auto kBegin = "begin";
constexpr auto kSrcFormat = "src_format";
@@ -5221,7 +5221,7 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
        return var_dtype, accum_dtype


class KLDivLoss(PrimitiveWithInfer):
class KLDivLoss(Primitive):
    r"""
    Computes the Kullback-Leibler divergence between the logits and the labels.

@@ -5229,27 +5229,28 @@

    .. math::
        L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = y_n \cdot (\log y_n - x_n)
        l_n = target_n \cdot (\log target_n - x_n)

    Then,

    .. math::
        \ell(x, y) = \begin{cases}
        \ell(x, target) = \begin{cases}
        L, & \text{if reduction} = \text{'none';}\\
        \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
        \operatorname{batchmean}(L), & \text{if reduction} = \text{'batchmean';}\\
        \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
        \end{cases}

    where :math:`x` represents `logits`.
    :math:`y` represents `labels`.
    :math:`\ell(x, y)` represents `output`.
    :math:`target` represents `labels`.
    :math:`\ell(x, target)` represents `output`.

    Args:
        reduction (str): Specifies the reduction to be applied to the output.
            Its value must be one of 'none', 'mean' or 'sum'. Default: 'mean'.
            Its value must be one of 'none', 'mean', 'batchmean' or 'sum'. Default: 'mean'.

    Inputs:
        - **logits** (Tensor) - The input Tensor. The data type must be float32.
        - **logits** (Tensor) - The input Tensor. The data type must be float16, float32 or float64.
        - **labels** (Tensor) - The label Tensor which has the same shape and data type as `logits`.

    Outputs:
@@ -5262,7 +5263,7 @@ class KLDivLoss(PrimitiveWithInfer):
        TypeError: If dtype of `logits` or `labels` is not float32.

    Supported Platforms:
        ``GPU``
        ``CPU`` ``GPU``

    Examples:
        >>> class Net(nn.Cell):
@@ -5284,21 +5285,17 @@ class KLDivLoss(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, reduction='mean'):
        """Initialize KLDivLoss."""
        self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)

    def infer_shape(self, x_shape, y_shape):
        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
        if self.reduction in ('mean', 'sum'):
            shape = []
        device_target = context.get_context("device_target")
        if device_target == "CPU":
            support_mode = ['none', 'mean', 'batchmean', 'sum']
        elif device_target == "GPU":
            support_mode = ['none', 'mean', 'sum']
        elif device_target == "Ascend":
            raise ValueError(f"'{self.name}' does not support Ascend platform currently.")
        else:
            shape = x_shape
        return shape
            raise ValueError(f"'{self.name}' unknown device target: '{device_target}'")

    def infer_dtype(self, x_type, y_type):
        args = {'x': x_type, 'y': y_type}
        valid_dtypes = (mstype.float16, mstype.float32)
        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
        return x_type
        self.reduction = validator.check_string(reduction, support_mode, 'reduction', self.name)


class BinaryCrossEntropy(Primitive):
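With the KLDivLoss __init__ shown in the hunk above, the accepted reduction strings now depend on the device target rather than being a fixed list. A rough usage sketch of the resulting behaviour (hypothetical session, not part of this diff):

import mindspore.context as context
from mindspore.ops import operations as P

context.set_context(device_target="CPU")
op = P.KLDivLoss(reduction="batchmean")  # accepted on CPU
# On GPU the same call raises ValueError because 'batchmean' is not in
# support_mode there, and on Ascend constructing the primitive is rejected.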
@@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class Net(nn.Cell):
    def __init__(self, reduction="none"):
        super(Net, self).__init__()
        self.kl_div_loss = P.KLDivLoss(reduction)

    def construct(self, x, y):
        return self.kl_div_loss(x, y)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_mode_none_and_dtype_with_static_input(dtype):
    """
    Feature: test none mode with different input dtype.
    Description: input with negative elements.
    Expectation: success.
    """
    np.random.seed(42)
    prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype))
    target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype))
    net = Net("none")
    loss = net(Tensor(prediction), Tensor(target))
    expect = np.array([[0, 0.35667494], [0.69314718, 0]]).astype(dtype)
    assert np.allclose(loss.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_mode_mean_and_dtype_with_static_input(dtype):
    """
    Feature: test mean mode with different input dtype.
    Description: input with negative elements.
    Expectation: success.
    """
    np.random.seed(42)
    prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype))
    target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype))
    net = Net("mean")
    loss = net(Tensor(prediction), Tensor(target))
    expect = np.array([0.26245553]).astype(dtype)
    assert np.allclose(loss.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_mode_sum_and_dtype_with_static_input(dtype):
    """
    Feature: test sum mode with different input dtype.
    Description: input with negative elements.
    Expectation: success.
    """
    np.random.seed(42)
    prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype))
    target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype))
    net = Net("sum")
    loss = net(Tensor(prediction), Tensor(target))
    expect = np.array([1.04982212]).astype(dtype)
    assert np.allclose(loss.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_mode_batchmean_and_dtype_with_static_input(dtype):
    """
    Feature: test batchmean mode with different input dtype.
    Description: input with negative elements.
    Expectation: success.
    """
    np.random.seed(42)
    prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype))
    target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype))
    net = Net("batchmean")
    loss = net(Tensor(prediction), Tensor(target))
    expect = np.array([0.52491106]).astype(dtype)
    assert np.allclose(loss.asnumpy(), expect)
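The constants asserted above follow directly from the element-wise definition target * (log(target) - prediction) with NaN terms zeroed. A quick NumPy check of the four reductions (illustrative only, not part of the test suite):

import numpy as np

prediction = np.log(np.array([[0.3, 0.7], [0.5, 0.5]]))
target = np.array([[-1.0, 1.0], [1.0, -1.0]])

with np.errstate(invalid="ignore"):
    elementwise = target * (np.log(target) - prediction)
elementwise = np.where(np.isnan(elementwise), 0.0, elementwise)

print(elementwise)                           # [[0, 0.35667494], [0.69314718, 0]] for 'none'
print(elementwise.mean())                    # 0.26245553 for 'mean'
print(elementwise.sum())                     # 1.04982212 for 'sum'
print(elementwise.sum() / target.shape[0])   # 0.52491106 for 'batchmean'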
@@ -13,7 +13,6 @@
# limitations under the License.

import numpy as np

import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.nn import Cell
@@ -41,6 +40,7 @@ def test_kldiv_loss_mean_auto_parallel():
    Description: auto parallel, reduction is 'mean'
    Expectation: compile success
    """
    context.set_context(device_target="GPU")
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
    reduction = 'mean'
    net = Net(reduction)
@@ -53,6 +53,7 @@ def test_kldiv_loss_none_auto_parallel():
    Description: auto parallel, reduction is 'none'
    Expectation: compile success
    """
    context.set_context(device_target="GPU")
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
    reduction = 'none'
    net = Net(reduction)
@@ -65,6 +66,7 @@ def test_kldiv_loss_sum_auto_parallel():
    Description: auto parallel, reduction is 'sum'
    Expectation: compile success
    """
    context.set_context(device_target="GPU")
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
    reduction = 'sum'
    net = Net(reduction)
@@ -77,6 +79,7 @@ def test_kldiv_loss_mean_data_parallel():
    Description: data parallel, reduction is 'mean'
    Expectation: compile success
    """
    context.set_context(device_target="GPU")
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
    reduction = 'mean'
    net = Net(reduction)
@@ -92,6 +95,7 @@ def test_kldiv_loss_none_data_parallel():
    Description: data parallel, reduction is 'none'
    Expectation: compile success
    """
    context.set_context(device_target="GPU")
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
    reduction = 'none'
    net = Net(reduction)
@@ -104,6 +108,7 @@ def test_kldiv_loss_none_model_parallel():
    Description: model parallel, reduction is 'none'
    Expectation: compile success
    """
    context.set_context(device_target="GPU")
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
    reduction = 'none'
    strategy = ((2, 2), (2, 2))
@@ -117,6 +122,7 @@ def test_kldiv_loss_mean_model_parallel():
    Description: model parallel, reduction is 'mean'
    Expectation: compile success
    """
    context.set_context(device_target="GPU")
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
    reduction = 'mean'
    strategy = ((4, 2), (4, 2))