forked from mindspore-Ecosystem/mindspore
!1763 add validator for InvertPermutation
Merge pull request !1763 from jiangjinsheng/issue_invert
This commit is contained in:
commit
584641180f
|
@ -1049,6 +1049,8 @@ class InvertPermutation(PrimitiveWithInfer):
|
|||
def __infer__(self, x):
|
||||
x_shp = x['shape']
|
||||
x_value = x['value']
|
||||
if x_value is None:
|
||||
raise ValueError(f'For \'{self.name}\' the input value must be const.')
|
||||
validator.check_value_type("shape", x_shp, [tuple, list], self.name)
|
||||
if mstype.issubclass_(x['dtype'], mstype.tensor):
|
||||
validator.check('x dimension', len(x_shp), '', 1, Rel.EQ, self.name)
|
||||
|
@ -1057,6 +1059,10 @@ class InvertPermutation(PrimitiveWithInfer):
|
|||
z = [x_value[i] for i in range(len(x_value))]
|
||||
z.sort()
|
||||
|
||||
validator.check(f'value length', len(x_value), f'unique value length', len(set(x_value)), Rel.EQ, self.name)
|
||||
validator.check(f'value min', min(x_value), '', 0, Rel.EQ, self.name)
|
||||
validator.check(f'value max', max(x_value), '', len(x_value)-1, Rel.EQ, self.name)
|
||||
|
||||
y = [None] * len(x_value)
|
||||
for i, value in enumerate(x_value):
|
||||
validator.check_value_type("input[%d]" % i, value, [int], self.name)
|
||||
|
|
Loading…
Reference in New Issue