fix argminwith value

This commit is contained in:
fangzehua 2020-07-25 09:52:23 +08:00
parent 0a74c8a52d
commit 1556ce86ea
1 changed files with 2 additions and 1 deletions

View File

@ -1232,7 +1232,8 @@ class ArgMinWithValue(PrimitiveWithInfer):
"""init ArgMinWithValue"""
self.axis = axis
self.keep_dims = keep_dims
_check_infer_attr_reduce(axis, keep_dims, self.name)
validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
validator.check_value_type('axis', axis, [int], self.name)
def infer_shape(self, x_shape):
axis = self.axis