forked from mindspore-Ecosystem/mindspore
!46490 bugfix, nllloss labels input must be in [0, C-1]
Merge pull request !46490 from zhangyanhui/develop_mas
This commit is contained in:
commit
ea7b2a377d
|
@ -39,3 +39,4 @@ mindspore.ops.NLLLoss
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `logits` 或 `weight` 的数据类型既不是float16也不是float32, `labels` 不是int32。
|
- **TypeError** - `logits` 或 `weight` 的数据类型既不是float16也不是float32, `labels` 不是int32。
|
||||||
- **ValueError** - `logits` 不是二维Tensor, `labels` 和 `weight` 不是一维Tensor。 `logits` 的第一个维度不等于 `labels` , `logits` 的第二个维度不等于 `weight` 。
|
- **ValueError** - `logits` 不是二维Tensor, `labels` 和 `weight` 不是一维Tensor。 `logits` 的第一个维度不等于 `labels` , `logits` 的第二个维度不等于 `weight` 。
|
||||||
|
- **ValueError** - `labels` 的取值超出 :math:`[0, C-1]` ,其中 :math:`C` 表示类的数量。
|
|
@ -25,6 +25,7 @@ namespace kernel {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr size_t kNLLLossInputsNum = 3;
|
constexpr size_t kNLLLossInputsNum = 3;
|
||||||
constexpr size_t kNLLLossOutputsNum = 2;
|
constexpr size_t kNLLLossOutputsNum = 2;
|
||||||
|
constexpr int minLabelNum = 0;
|
||||||
const std::map<Reduction, ReductionType> kReductionMap = {
|
const std::map<Reduction, ReductionType> kReductionMap = {
|
||||||
{Reduction::MEAN, Reduction_Mean}, {Reduction::REDUCTION_SUM, Reduction_Sum}, {Reduction::NONE, Reduction_None}};
|
{Reduction::MEAN, Reduction_Mean}, {Reduction::REDUCTION_SUM, Reduction_Sum}, {Reduction::NONE, Reduction_None}};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -82,6 +83,15 @@ bool NLLLossCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||||
const auto *weight = reinterpret_cast<float *>(inputs[kIndex2]->addr);
|
const auto *weight = reinterpret_cast<float *>(inputs[kIndex2]->addr);
|
||||||
auto *loss = reinterpret_cast<float *>(outputs[kIndex0]->addr);
|
auto *loss = reinterpret_cast<float *>(outputs[kIndex0]->addr);
|
||||||
auto *total_weight = reinterpret_cast<float *>(outputs[kIndex1]->addr);
|
auto *total_weight = reinterpret_cast<float *>(outputs[kIndex1]->addr);
|
||||||
|
if (logits == NULL || labels == NULL || weight == NULL) {
|
||||||
|
MS_LOG(EXCEPTION) << "Nllloss does not support null input";
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < nllloss_param_.batch_; i++) {
|
||||||
|
if (labels[i] < minLabelNum || labels[i] > nllloss_param_.class_num_) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the label must in scope[0, C-1], but got" << labels[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int ret = NLLLoss(logits, labels, weight, loss, total_weight, &nllloss_param_);
|
int ret = NLLLoss(logits, labels, weight, loss, total_weight, &nllloss_param_);
|
||||||
if (ret != static_cast<int>(NNACL_OK)) {
|
if (ret != static_cast<int>(NNACL_OK)) {
|
||||||
|
|
Loading…
Reference in New Issue