From 7b7a6a45a043f5e32d0daa4b82ef214ff70592f6 Mon Sep 17 00:00:00 2001 From: seatea Date: Mon, 30 Mar 2020 12:10:35 +0800 Subject: [PATCH] Check if the shape of the input of NMSWithMask is (N, 5). --- mindspore/ops/operations/math_ops.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 213e04e351b..ad928d792f3 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -1745,11 +1745,14 @@ class NMSWithMask(PrimitiveWithInfer): self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) def infer_shape(self, bboxes_shape): + validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ) + validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT) + validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ) num = bboxes_shape[0] - validator.check_integer("bboxes_shape[0]", num, 0, Rel.GT) return (bboxes_shape, (num,), (num,)) def infer_dtype(self, bboxes_dtype): + validator.check_subclass("bboxes_dtype", bboxes_dtype, mstype.tensor) validator.check_typename("bboxes_dtype", bboxes_dtype, [mstype.float16, mstype.float32]) return (bboxes_dtype, mstype.int32, mstype.bool_)