forked from mindspore-Ecosystem/mindspore
!9474 SparseGatherV2: Raise error if tensor is provided as input for axis
From: @peilin-wang Reviewed-by: @tom__chen,@robingrosman Signed-off-by: @robingrosman
This commit is contained in:
commit
a5fd09448f
|
@ -807,7 +807,7 @@ class GatherV2(PrimitiveWithCheck):
|
|||
def __check__(self, params, indices, axis):
|
||||
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
|
||||
validator.check_subclass("axis", axis['dtype'], [mstype.tensor, mstype.int_], self.name)
|
||||
validator.check_subclass("axis", axis['dtype'], [mstype.int_], self.name)
|
||||
|
||||
|
||||
class SparseGatherV2(GatherV2):
|
||||
|
|
Loading…
Reference in New Issue