!34511 Add CPU support for KLDivLoss op

Merge pull request !34511 from zhuyuxiao/I51VMV
i-robot 2022-05-23 11:06:37 +00:00 committed by Gitee
commit 4f0b1dc60e
11 changed files with 563 additions and 22 deletions
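For orientation, here is a minimal usage sketch of what this change enables on the CPU backend, based on the primitive and test code in the diffs below. The network class name is my own; I am only assuming the mindspore.ops.operations.KLDivLoss API and context settings that appear in the tests.

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

class KLDivNet(nn.Cell):
    def __init__(self, reduction="batchmean"):
        super(KLDivNet, self).__init__()
        # 'batchmean' is accepted on CPU after this change (GPU keeps 'none'/'mean'/'sum')
        self.kl_div_loss = P.KLDivLoss(reduction)

    def construct(self, x, target):
        return self.kl_div_loss(x, target)

x = Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]], dtype=np.float32)))
target = Tensor(np.array([[0.4, 0.6], [0.5, 0.5]], dtype=np.float32))
print(KLDivNet("batchmean")(x, target))  # scalar: sum of target*(log(target)-x) divided by the batch size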


@ -0,0 +1,201 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/kl_div_loss_cpu_kernel.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <map>
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/core/ops/kl_div_loss.h"
#include "include/common/thread_pool.h"
namespace mindspore {
namespace kernel {
const size_t kKLDivLossInputsNum = 2;
const size_t kKLDivLossOutputsNum = 1;
bool KLDivLossCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::KLDivLoss>(base_operator);
if (kernel_ptr == nullptr) {
MS_LOG(ERROR) << "For 'KLDivLoss', cast BaseOperator to ops::KLDivLoss failed.";
return false;
}
kernel_name_ = kernel_ptr->name();
reductionMode_ = kernel_ptr->get_reduction();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kKLDivLossInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kKLDivLossOutputsNum, kernel_name_);
batch_size_ = inputs[kIndex0]->GetShapeVector()[kIndex0];
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
return true;
}
bool KLDivLossCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
MS_EXCEPTION_IF_NULL(kernel_func_);
return kernel_func_(this, inputs, workspace, outputs);
}
int KLDivLossCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &onHost) {
int ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, onHost);
if (ret != 0) {
MS_LOG(WARNING) << kernel_name_ << " resize failed.";
return ret;
}
// Shapes may change between calls, so recompute the element counts from scratch.
input_x_shape_size_ = 1;
input_target_shape_size_ = 1;
std::vector<int64_t> input_x_shape = inputs[kIndex0]->GetShapeVector();
for (size_t i = 0; i < input_x_shape.size(); ++i) {
input_x_shape_size_ *= input_x_shape[i];
}
std::vector<int64_t> input_target_shape = inputs[kIndex1]->GetShapeVector();
for (size_t i = 0; i < input_target_shape.size(); ++i) {
input_target_shape_size_ *= input_target_shape[i];
}
if (reductionMode_ != ops::kNone) {
size_t type_size = GetTypeByte(TypeIdToType(inputs[kIndex0]->GetDtype()));
workspace_size_list_.push_back(input_x_shape_size_ * type_size);
}
return ret;
}
std::vector<KernelAttr> KLDivLossCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list;
(void)std::transform(
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, KLDivLossCpuKernelMod::KLDivLossFunc> &pair) { return pair.first; });
return support_list;
}
bool KLDivLossCpuKernelMod::CheckParams() {
// For KLDivLoss, x (input 0) and target (input 1) must contain the same number of elements.
if (input_target_shape_size_ != input_x_shape_size_) {
MS_LOG(ERROR) << kernel_name_ << ": input x shape size = " << input_x_shape_size_
<< ", input target shape size = " << input_target_shape_size_ << ". They are not the same.";
return false;
}
return true;
}
template <typename T>
bool KLDivLossCpuKernelMod::LaunchNoneReduction(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
auto *input_x = reinterpret_cast<T *>(inputs[kIndex0]->addr);
auto *input_target = reinterpret_cast<T *>(inputs[kIndex1]->addr);
auto *y = reinterpret_cast<T *>(outputs[kIndex0]->addr);
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_x(input_x, input_x_shape_size_, 1);
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_target(input_target, input_target_shape_size_, 1);
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_y(y, input_x_shape_size_, 1);
array_y = array_target * (Eigen::log(array_target) - array_x);
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
if (std::isnan(static_cast<float>(array_y[i]))) {
array_y[i] = static_cast<T>(0);
}
}
};
ParallelLaunchAutoSearch(task, input_x_shape_size_, this, &parallel_search_info_);
return true;
}
template <typename T>
bool KLDivLossCpuKernelMod::LaunchOther(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
auto *input_x = reinterpret_cast<T *>(inputs[kIndex0]->addr);
auto *input_target = reinterpret_cast<T *>(inputs[kIndex1]->addr);
auto *tmp_result = reinterpret_cast<T *>(workspace[kIndex0]->addr);
auto *y = reinterpret_cast<T *>(outputs[kIndex0]->addr);
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_x(input_x, input_x_shape_size_, 1);
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_target(input_target, input_target_shape_size_, 1);
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_tmp(tmp_result, input_x_shape_size_, 1);
array_tmp = array_target * (Eigen::log(array_target) - array_x);
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
if (std::isnan(static_cast<float>(array_tmp[i]))) {
array_tmp[i] = static_cast<T>(0);
}
}
};
ParallelLaunchAutoSearch(task, input_x_shape_size_, this, &parallel_search_info_);
if (reductionMode_ == ops::kSum) {
y[kIndex0] = array_tmp.sum();
return true;
}
if (reductionMode_ == ops::kMean) {
y[kIndex0] = array_tmp.mean();
return true;
}
if (reductionMode_ == ops::kBatchMean) {
y[kIndex0] = array_tmp.sum() / static_cast<T>(batch_size_);
return true;
}
return false;
}
template <typename T>
bool KLDivLossCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kKLDivLossInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kKLDivLossOutputsNum, kernel_name_);
if (!KLDivLossCpuKernelMod::CheckParams()) {
MS_LOG(EXCEPTION) << kernel_name_ << ": check param failed.";
}
if (reductionMode_ == ops::kNone) {
return LaunchNoneReduction<T>(inputs, workspace, outputs);
}
return LaunchOther<T>(inputs, workspace, outputs);
}
std::vector<std::pair<KernelAttr, KLDivLossCpuKernelMod::KLDivLossFunc>> KLDivLossCpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&KLDivLossCpuKernelMod::LaunchKernel<Eigen::half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&KLDivLossCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&KLDivLossCpuKernelMod::LaunchKernel<double>}};
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, KLDivLoss, KLDivLossCpuKernelMod);
} // namespace kernel
} // namespace mindspore
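As a cross-check on the kernel above, this is a NumPy reference sketch of the same computation (my own helper, not part of the patch): the elementwise loss is target * (log(target) - x), NaN entries are zeroed exactly as in LaunchNoneReduction/LaunchOther, and 'batchmean' divides the sum by the first dimension of x (batch_size_).

import numpy as np

def kl_div_loss_reference(x, target, reduction="mean"):
    # Elementwise KL divergence term, as computed by the Eigen expression in the kernel.
    loss = target * (np.log(target) - x)
    # The kernel replaces NaN (e.g. from log of a non-positive target) with 0.
    loss = np.where(np.isnan(loss), 0.0, loss)
    if reduction == "none":
        return loss
    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.mean()
    if reduction == "batchmean":
        return loss.sum() / x.shape[0]
    raise ValueError(f"unsupported reduction: {reduction}")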


@ -0,0 +1,77 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_KL_DIV_LOSS_CPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_KL_DIV_LOSS_CPU_KERNEL_H
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class KLDivLossCpuKernelMod : public NativeCpuKernelMod {
public:
KLDivLossCpuKernelMod() = default;
~KLDivLossCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &onHost = std::map<uint32_t, tensor::TensorPtr>()) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
template <typename T>
bool LaunchNoneReduction(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
template <typename T>
bool LaunchOther(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool CheckParams();
using KLDivLossFunc = std::function<bool(KLDivLossCpuKernelMod *, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
private:
static std::vector<std::pair<KernelAttr, KLDivLossFunc>> func_list_;
KLDivLossFunc kernel_func_;
std::string reductionMode_;
int64_t batch_size_{0};
size_t input_x_shape_size_{1};
size_t input_target_shape_size_{1};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_KL_DIV_LOSS_CPU_KERNEL_H


@ -390,6 +390,14 @@ template CUDA_LIB_EXPORT void KLDivLossGrad<float>(const int &input_size, const
const float *input_x, const float *input_y, const float *dloss,
float *dx, float *dy, cudaStream_t stream);
template CUDA_LIB_EXPORT void KLDivLoss<double>(const int &input_size, const ReductionMode &reduction,
const double *input_x, const double *input_y, double *loss,
double *tmp_loss, cudaStream_t stream);
template CUDA_LIB_EXPORT void KLDivLossGrad<double>(const int &input_size, const ReductionMode &reduction,
const double *input_x, const double *input_y, const double *dloss,
double *dx, double *dy, cudaStream_t stream);
template CUDA_LIB_EXPORT void BinaryCrossEntropyLoss<float>(const int &input_size, const ReductionMode &reduction,
const float *input_x, const float *input_y,
const float *weight, float *loss, float *tmp_loss,


@ -26,5 +26,9 @@ MS_REG_GPU_KERNEL_ONE(
KLDivLoss,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KLDivLossGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
KLDivLoss,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
KLDivLossGpuKernelMod, double)
} // namespace kernel
} // namespace mindspore


@ -847,6 +847,7 @@ GVAR_DEF(PrimitivePtr, kPrimZeta, std::make_shared<Primitive>("Zeta"));
GVAR_DEF(PrimitivePtr, kPrimIgamma, std::make_shared<Primitive>("Igamma"));
GVAR_DEF(PrimitivePtr, kPrimIgammac, std::make_shared<Primitive>("Igammac"));
GVAR_DEF(PrimitivePtr, kPrimIgammaGradA, std::make_shared<Primitive>("IgammaGradA"));
GVAR_DEF(PrimitivePtr, kPrimKLDivLoss, std::make_shared<Primitive>("KLDivLoss"));
// linalg
GVAR_DEF(PrimitivePtr, kPrimSvd, std::make_shared<Primitive>("Svd"));


@ -0,0 +1,83 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/kl_div_loss.h"
#include <map>
#include <set>
#include "mindapi/ir/type.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
void KLDivLoss::Init(const std::string &reduction) { set_reduction(reduction); }
void KLDivLoss::set_reduction(const std::string &reduction) {
(void)this->AddAttr(kReduction, api::MakeValue(reduction));
}
std::string KLDivLoss::get_reduction() const { return GetValue<std::string>(GetAttr(ops::kReduction)); }
abstract::ShapePtr KLDivLossInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
auto input_x_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
auto input_x_shape = input_x_map[kShape];
auto input_target_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
auto input_target_shape = input_target_map[kShape];
CheckAndConvertUtils::Check("x shape", input_x_shape, kEqual, input_target_shape, op_name, ValueError);
auto reduction = GetValue<std::string>(primitive->GetAttr(kReduction));
if (reduction == kNone) {
return std::make_shared<abstract::Shape>(input_x_shape);
}
if (reduction == kBatchMean && input_x_shape.empty()) {
MS_LOG(EXCEPTION) << "For " << op_name << ", reduction 'batchmean' cannot be applied to a scalar input (x shape = []).";
}
std::vector<int64_t> y_shape;
return std::make_shared<abstract::Shape>(y_shape);
}
TypePtr KLDivLossInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = prim->name();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
auto input_x_type = input_args[kInputIndex0]->BuildType();
auto input_target_type = input_args[kInputIndex1]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, op_name);
std::map<std::string, TypePtr> types;
types.emplace("x", input_x_type);
types.emplace("target", input_target_type);
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
return input_x_type;
}
MIND_API_OPERATOR_IMPL(KLDivLoss, BaseOperator);
AbstractBasePtr KLDivLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputsNum = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
auto infer_shape = KLDivLossInferShape(primitive, input_args);
auto infer_type = KLDivLossInferType(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(KLDivLoss, prim::kPrimKLDivLoss, KLDivLossInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
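In plain terms, the shape inference above keeps the input shape for reduction 'none', rejects 'batchmean' on a scalar input, and otherwise produces a scalar. A hypothetical Python sketch of that rule, for illustration only:

def kl_div_loss_infer_shape(x_shape, target_shape, reduction):
    # x and target must have identical shapes.
    if x_shape != target_shape:
        raise ValueError("x shape and target shape must be equal")
    if reduction == "none":
        return x_shape              # elementwise output keeps the input shape
    if reduction == "batchmean" and len(x_shape) == 0:
        raise ValueError("cannot apply 'batchmean' to a scalar input")
    return []                       # 'mean', 'batchmean' and 'sum' reduce to a scalar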


@ -0,0 +1,52 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_KL_DIV_LOSS_H
#define MINDSPORE_CORE_OPS_KL_DIV_LOSS_H
#include <string>
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "ops/op_name.h"
namespace mindspore {
namespace ops {
constexpr auto kNameKLDivLoss = "KLDivLoss";
/// \brief Computes the Kullback-Leibler divergence between the logits and the labels.
/// Refer to Python API @ref mindspore.ops.KLDivLoss for more details.
class MIND_API KLDivLoss : public BaseOperator {
public:
MIND_API_BASE_MEMBER(KLDivLoss);
/// \brief Constructor.
KLDivLoss() : BaseOperator(kNameKLDivLoss) { InitIOName({"x", "target"}, {"y"}); }
explicit KLDivLoss(const std::string &k_name) : BaseOperator(k_name) { InitIOName({"x", "target"}, {"y"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.KLDivLoss for the inputs.
void Init(const std::string &reduction = kMean);
/// \brief Set reduction.
void set_reduction(const std::string &reduction);
/// \brief Get reduction.
///
/// \return reduction.
std::string get_reduction() const;
};
abstract::AbstractBasePtr KLDivLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_KL_DIV_LOSS_H


@ -265,7 +265,10 @@ constexpr auto kIsOriginalPadMode = "is_original_pad_mode";
constexpr auto kOriginalOpName = "original_op_name";
constexpr auto kSymmetric = "symmetric";
constexpr auto kDstType = "dst_type";
constexpr auto kNone = "none";
constexpr auto kMean = "mean";
constexpr auto kBatchMean = "batchmean";
constexpr auto kSum = "sum";
constexpr auto kIndices = "indices";
constexpr auto kBegin = "begin";
constexpr auto kSrcFormat = "src_format";


@ -5221,7 +5221,7 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
return var_dtype, accum_dtype
class KLDivLoss(PrimitiveWithInfer):
class KLDivLoss(Primitive):
r"""
Computes the Kullback-Leibler divergence between the logits and the labels.
@ -5229,27 +5229,28 @@ class KLDivLoss(PrimitiveWithInfer):
.. math::
L = \{l_1,\dots,l_N\}^\top, \quad
l_n = y_n \cdot (\log y_n - x_n)
l_n = target_n \cdot (\log target_n - x_n)
Then,
.. math::
\ell(x, y) = \begin{cases}
\ell(x, target) = \begin{cases}
L, & \text{if reduction} = \text{'none';}\\
\operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
\operatorname{batchmean}(L), & \text{if reduction} = \text{'batchmean';}\\
\operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
\end{cases}
where :math:`x` represents `logits`.
:math:`y` represents `labels`.
:math:`\ell(x, y)` represents `output`.
:math:`target` represents `labels`.
:math:`\ell(x, target)` represents `output`.
Args:
reduction (str): Specifies the reduction to be applied to the output.
Its value must be one of 'none', 'mean' or 'sum'. Default: 'mean'.
Its value must be one of 'none', 'mean', 'batchmean' or 'sum'. Default: 'mean'.
Inputs:
- **logits** (Tensor) - The input Tensor. The data type must be float32.
- **logits** (Tensor) - The input Tensor. The data type must be float16, float32 or float64.
- **labels** (Tensor) - The label Tensor which has the same shape and data type as `logits`.
Outputs:
@ -5262,7 +5263,7 @@ class KLDivLoss(PrimitiveWithInfer):
TypeError: If dtype of `logits` or `labels` is not float16, float32 or float64.
Supported Platforms:
``GPU``
``CPU`` ``GPU``
Examples:
>>> class Net(nn.Cell):
@ -5284,21 +5285,17 @@ class KLDivLoss(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, reduction='mean'):
"""Initialize KLDivLoss."""
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
def infer_shape(self, x_shape, y_shape):
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
if self.reduction in ('mean', 'sum'):
shape = []
else:
shape = x_shape
return shape
def infer_dtype(self, x_type, y_type):
args = {'x': x_type, 'y': y_type}
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
return x_type
device_target = context.get_context("device_target")
if device_target == "CPU":
support_mode = ['none', 'mean', 'batchmean', 'sum']
elif device_target == "GPU":
support_mode = ['none', 'mean', 'sum']
elif device_target == "Ascend":
raise ValueError(f"'{self.name}' does not support Ascend platform currently.")
else:
raise ValueError(f"'{self.name}' unknown device target: '{device_target}'")
self.reduction = validator.check_string(reduction, support_mode, 'reduction', self.name)
class BinaryCrossEntropy(Primitive):
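The rewritten KLDivLoss.__init__ above makes the accepted reduction values depend on the backend. A short sketch of the resulting behaviour (assuming validator.check_string raises ValueError for an unsupported value, which is how I read it):

import mindspore.context as context
from mindspore.ops import operations as P

context.set_context(device_target="CPU")
loss_cpu = P.KLDivLoss("batchmean")      # accepted: CPU supports 'none', 'mean', 'batchmean', 'sum'

context.set_context(device_target="GPU")
try:
    loss_gpu = P.KLDivLoss("batchmean")  # rejected: GPU supports only 'none', 'mean', 'sum'
except ValueError as err:
    print(err)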


@ -0,0 +1,109 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Net(nn.Cell):
def __init__(self, reduction="none"):
super(Net, self).__init__()
self.kl_div_loss = P.KLDivLoss(reduction)
def construct(self, x, y):
return self.kl_div_loss(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_mode_none_and_dtype_with_static_input(dtype):
"""
Feature: test none mode with different input dtype.
Description: input with negative elements.
Expectation: success.
"""
np.random.seed(42)
prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype))
target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype))
net = Net("none")
loss = net(Tensor(prediction), Tensor(target))
expect = np.array([[0, 0.35667494], [0.69314718, 0]]).astype(dtype)
assert np.allclose(loss.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_mode_mean_and_dtype_with_static_input(dtype):
"""
Feature: test mean mode with different input dtype.
Description: input with negative elements.
Expectation: success.
"""
np.random.seed(42)
prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype))
target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype))
net = Net("mean")
loss = net(Tensor(prediction), Tensor(target))
expect = np.array([0.26245553]).astype(dtype)
assert np.allclose(loss.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_mode_sum_and_dtype_with_static_input(dtype):
"""
Feature: test sum mode with different input dtype.
Description: input with negative elements.
Expectation: success.
"""
np.random.seed(42)
prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype))
target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype))
net = Net("sum")
loss = net(Tensor(prediction), Tensor(target))
expect = np.array([1.04982212]).astype(dtype)
assert np.allclose(loss.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_mode_batchmean_and_dtype_with_static_input(dtype):
"""
Feature: test batchmean mode with different input dtype.
Description: input with negative elements.
Expectation: success.
"""
np.random.seed(42)
prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype))
target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype))
net = Net("batchmean")
loss = net(Tensor(prediction), Tensor(target))
expect = np.array([0.52491106]).astype(dtype)
assert np.allclose(loss.asnumpy(), expect)
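For reference, the expected constants in the tests above can be reproduced by hand; this is my own check, not part of the patch:

import numpy as np

x = np.log(np.array([[0.3, 0.7], [0.5, 0.5]]))
target = np.array([[-1.0, 1.0], [1.0, -1.0]])
loss = target * (np.log(target) - x)        # log of a negative target gives NaN
loss = np.where(np.isnan(loss), 0.0, loss)  # [[0, 0.35667494], [0.69314718, 0]]
print(loss.sum())                           # 1.04982212  -> 'sum'
print(loss.mean())                          # 0.26245553  -> 'mean'
print(loss.sum() / x.shape[0])              # 0.52491106  -> 'batchmean'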


@ -13,7 +13,6 @@
# limitations under the License.
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.nn import Cell
@ -41,6 +40,7 @@ def test_kldiv_loss_mean_auto_parallel():
Description: auto parallel, reduction is 'mean'
Expectation: compile success
"""
context.set_context(device_target="GPU")
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
reduction = 'mean'
net = Net(reduction)
@ -53,6 +53,7 @@ def test_kldiv_loss_none_auto_parallel():
Description: auto parallel, reduction is 'none'
Expectation: compile success
"""
context.set_context(device_target="GPU")
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
reduction = 'none'
net = Net(reduction)
@ -65,6 +66,7 @@ def test_kldiv_loss_sum_auto_parallel():
Description: auto parallel, reduction is 'sum'
Expectation: compile success
"""
context.set_context(device_target="GPU")
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
reduction = 'sum'
net = Net(reduction)
@ -77,6 +79,7 @@ def test_kldiv_loss_mean_data_parallel():
Description: data parallel, reduction is 'mean'
Expectation: compile success
"""
context.set_context(device_target="GPU")
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
reduction = 'mean'
net = Net(reduction)
@ -92,6 +95,7 @@ def test_kldiv_loss_none_data_parallel():
Description: data parallel, reduction is 'none'
Expectation: compile success
"""
context.set_context(device_target="GPU")
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
reduction = 'none'
net = Net(reduction)
@ -104,6 +108,7 @@ def test_kldiv_loss_none_model_parallel():
Description: model parallel, reduction is 'none'
Expectation: compile success
"""
context.set_context(device_target="GPU")
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
reduction = 'none'
strategy = ((2, 2), (2, 2))
@ -117,6 +122,7 @@ def test_kldiv_loss_mean_model_parallel():
Description: model parallel, reduction is 'mean'
Expectation: compile success
"""
context.set_context(device_target="GPU")
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
reduction = 'mean'
strategy = ((4, 2), (4, 2))