!46490 bugfix, nllloss labels input must be in [0, C-1]

Merge pull request !46490 from zhangyanhui/develop_mas
This commit is contained in:
i-robot 2022-12-07 02:02:09 +00:00 committed by Gitee
commit ea7b2a377d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 12 additions and 1 deletions

View File

@ -39,3 +39,4 @@ mindspore.ops.NLLLoss
异常:
- **TypeError** - `logits``weight` 的数据类型既不是float16也不是float32 `labels` 不是int32。
- **ValueError** - `logits` 不是二维Tensor `labels``weight` 不是一维Tensor。 `logits` 的第一个维度不等于 `labels` `logits` 的第二个维度不等于 `weight`
- **ValueError** - `labels` 的取值超出 :math:`[0, C-1]` ,其中 :math:`C` 表示类的数量。

View File

@ -25,6 +25,7 @@ namespace kernel {
namespace {
constexpr size_t kNLLLossInputsNum = 3;
constexpr size_t kNLLLossOutputsNum = 2;
constexpr int minLabelNum = 0;
const std::map<Reduction, ReductionType> kReductionMap = {
{Reduction::MEAN, Reduction_Mean}, {Reduction::REDUCTION_SUM, Reduction_Sum}, {Reduction::NONE, Reduction_None}};
} // namespace
@ -82,6 +83,15 @@ bool NLLLossCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const auto *weight = reinterpret_cast<float *>(inputs[kIndex2]->addr);
auto *loss = reinterpret_cast<float *>(outputs[kIndex0]->addr);
auto *total_weight = reinterpret_cast<float *>(outputs[kIndex1]->addr);
if (logits == NULL || labels == NULL || weight == NULL) {
MS_LOG(EXCEPTION) << "Nllloss does not support null input";
}
for (int i = 0; i < nllloss_param_.batch_; i++) {
if (labels[i] < minLabelNum || labels[i] > nllloss_param_.class_num_) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the label must in scope[0, C-1], but got" << labels[i];
}
}
int ret = NLLLoss(logits, labels, weight, loss, total_weight, &nllloss_param_);
if (ret != static_cast<int>(NNACL_OK)) {