!9474 SparseGatherV2: Raise error if tensor is provided as input for axis

From: @peilin-wang
Reviewed-by: @tom__chen,@robingrosman
Signed-off-by: @robingrosman
This commit is contained in:
mindspore-ci-bot 2020-12-05 05:28:36 +08:00 committed by Gitee
commit a5fd09448f
1 changed files with 1 additions and 1 deletions

View File

@ -807,7 +807,7 @@ class GatherV2(PrimitiveWithCheck):
def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.tensor, mstype.int_], self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.int_], self.name)
class SparseGatherV2(GatherV2):