nllloss input for logits only support 2D now
This commit is contained in:
parent
fb4a81add0
commit
9dcac9de4e
|
@ -38,4 +38,4 @@ mindspore.ops.NLLLoss
|
|||
|
||||
异常:
|
||||
- **TypeError** - `logits` 或 `weight` 的数据类型既不是float16也不是float32, `labels` 不是int32。
|
||||
- **ValueError** - `logits` 不是一维或二维Tensor, `labels` 和 `weight` 不是一维Tensor。 `logits` 是二维Tensor时, `logits` 的第一个维度不等于 `labels` , `logits` 的第二个维度不等于 `weight` 。 `logits` 是一维Tensor时, `logits` 、 `labels` 和 `weight` 的维度应该相同。
|
||||
- **ValueError** - `logits` 不是二维Tensor, `labels` 和 `weight` 不是一维Tensor。 `logits` 的第一个维度不等于 `labels` , `logits` 的第二个维度不等于 `weight` 。
|
|
@ -40,10 +40,11 @@ class NLLLossInfer : public abstract::OpInferBase {
|
|||
auto weight_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(weight_shape_ptr)[kShape];
|
||||
|
||||
const int64_t dims_2D = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("rank of logits", SizeToLong(logits_shape.size()), kEqual, dims_2D,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("rank of target", SizeToLong(target_shape.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("rank of weight", SizeToLong(weight_shape.size()), kEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::CheckInRange("rank of logits", SizeToLong(logits_shape.size()), kIncludeBoth, {1, 2},
|
||||
prim_name);
|
||||
|
||||
if (!logits_shape_ptr->IsDynamic()) {
|
||||
if (!target_shape_ptr->IsDynamic() && logits_shape[kInputIndex0] != target_shape[kInputIndex0]) {
|
||||
|
|
Loading…
Reference in New Issue