forked from mindspore-Ecosystem/mindspore
!13 Check input shape for `NMSWithMask` op
Merge pull request !13 from seatea/NMSWithMask-check-shape
This commit is contained in:
commit
44cd0c1f90
|
@ -1745,11 +1745,14 @@ class NMSWithMask(PrimitiveWithInfer):
|
||||||
self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask'])
|
self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask'])
|
||||||
|
|
||||||
def infer_shape(self, bboxes_shape):
|
def infer_shape(self, bboxes_shape):
|
||||||
|
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ)
|
||||||
|
validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT)
|
||||||
|
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ)
|
||||||
num = bboxes_shape[0]
|
num = bboxes_shape[0]
|
||||||
validator.check_integer("bboxes_shape[0]", num, 0, Rel.GT)
|
|
||||||
return (bboxes_shape, (num,), (num,))
|
return (bboxes_shape, (num,), (num,))
|
||||||
|
|
||||||
def infer_dtype(self, bboxes_dtype):
|
def infer_dtype(self, bboxes_dtype):
|
||||||
|
validator.check_subclass("bboxes_dtype", bboxes_dtype, mstype.tensor)
|
||||||
validator.check_typename("bboxes_dtype", bboxes_dtype, [mstype.float16, mstype.float32])
|
validator.check_typename("bboxes_dtype", bboxes_dtype, [mstype.float16, mstype.float32])
|
||||||
return (bboxes_dtype, mstype.int32, mstype.bool_)
|
return (bboxes_dtype, mstype.int32, mstype.bool_)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue