diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 213e04e351b..ad928d792f3 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -1745,11 +1745,14 @@ class NMSWithMask(PrimitiveWithInfer): self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) def infer_shape(self, bboxes_shape): + validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ) + validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT) + validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ) num = bboxes_shape[0] - validator.check_integer("bboxes_shape[0]", num, 0, Rel.GT) return (bboxes_shape, (num,), (num,)) def infer_dtype(self, bboxes_dtype): + validator.check_subclass("bboxes_dtype", bboxes_dtype, mstype.tensor) validator.check_typename("bboxes_dtype", bboxes_dtype, [mstype.float16, mstype.float32]) return (bboxes_dtype, mstype.int32, mstype.bool_)