nllloss input for logits only support 2D now

This commit is contained in:
zhangyanhui 2022-11-29 20:23:30 +08:00
parent fb4a81add0
commit 9dcac9de4e
2 changed files with 4 additions and 3 deletions

View File

@ -38,4 +38,4 @@ mindspore.ops.NLLLoss
异常:
- **TypeError** - `logits``weight` 的数据类型既不是float16也不是float32 `labels` 不是int32。
- **ValueError** - `logits` 不是一维或二维Tensor `labels``weight` 不是一维Tensor。 `logits` 是二维Tensor时 `logits` 的第一个维度不等于 `labels` `logits` 的第二个维度不等于 `weight` `logits` 是一维Tensor时 `logits``labels``weight` 的维度应该相同。
- **ValueError** - `logits` 不是二维Tensor `labels``weight` 不是一维Tensor。 `logits` 的第一个维度不等于 `labels` `logits` 的第二个维度不等于 `weight`

View File

@ -40,10 +40,11 @@ class NLLLossInfer : public abstract::OpInferBase {
auto weight_shape_ptr = input_args[kInputIndex2]->BuildShape();
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(weight_shape_ptr)[kShape];
const int64_t dims_2D = 2;
(void)CheckAndConvertUtils::CheckInteger("rank of logits", SizeToLong(logits_shape.size()), kEqual, dims_2D,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("rank of target", SizeToLong(target_shape.size()), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInteger("rank of weight", SizeToLong(weight_shape.size()), kEqual, 1, prim_name);
CheckAndConvertUtils::CheckInRange("rank of logits", SizeToLong(logits_shape.size()), kIncludeBoth, {1, 2},
prim_name);
if (!logits_shape_ptr->IsDynamic()) {
if (!target_shape_ptr->IsDynamic() && logits_shape[kInputIndex0] != target_shape[kInputIndex0]) {