This commit is contained in:
zong_shuai 2022-09-30 09:05:14 +08:00
parent cc66b9879e
commit 6bc07a5b58
1 changed files with 15 additions and 15 deletions

View File

@ -6304,23 +6304,22 @@ class _TensorScatterOp(PrimitiveWithInfer):
"""
def infer_shape(self, input_x_shape, indices_shape, updates_shape):
if len(indices_shape) < 2:
if indices_shape != [-2] and len(indices_shape) < 2:
raise ValueError(f"For '{self.name}', the dimension of 'indices' cannot be less than 2,"
f" but got {len(indices_shape)}.")
if indices_shape[-1] > len(input_x_shape):
raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
f"the dimension of 'input_x', but got the "
f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
f"{len(input_x_shape)}.")
updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
if self._check_shape(updates_shape_check, updates_shape) is False:
raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
f"but got the shape of 'update': {updates_shape}, "
f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")
if indices_shape[-1] > 0:
if indices_shape[-1] > len(input_x_shape):
raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
f"the dimension of 'input_x', but got the "
f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
f"{len(input_x_shape)}.")
updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
if self._check_shape(updates_shape_check, updates_shape) is False:
raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
f"but got the shape of 'update': {updates_shape}, "
f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")
return input_x_shape
@ -6328,6 +6327,7 @@ class _TensorScatterOp(PrimitiveWithInfer):
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
args = {"input_x": input_x_dtype, "updates": updates_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return input_x_dtype
def _check_shape(self, expect, real):