forked from mindspore-Ecosystem/mindspore
debug
This commit is contained in:
parent
cc66b9879e
commit
6bc07a5b58
|
@ -6304,23 +6304,22 @@ class _TensorScatterOp(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
def infer_shape(self, input_x_shape, indices_shape, updates_shape):
|
||||
if len(indices_shape) < 2:
|
||||
if indices_shape != [-2] and len(indices_shape) < 2:
|
||||
raise ValueError(f"For '{self.name}', the dimension of 'indices' cannot be less than 2,"
|
||||
f" but got {len(indices_shape)}.")
|
||||
|
||||
if indices_shape[-1] > len(input_x_shape):
|
||||
raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
|
||||
f"the dimension of 'input_x', but got the "
|
||||
f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
|
||||
f"{len(input_x_shape)}.")
|
||||
|
||||
updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
|
||||
if self._check_shape(updates_shape_check, updates_shape) is False:
|
||||
raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
|
||||
f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
|
||||
f"but got the shape of 'update': {updates_shape}, "
|
||||
f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
|
||||
f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")
|
||||
if indices_shape[-1] > 0:
|
||||
if indices_shape[-1] > len(input_x_shape):
|
||||
raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
|
||||
f"the dimension of 'input_x', but got the "
|
||||
f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
|
||||
f"{len(input_x_shape)}.")
|
||||
updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
|
||||
if self._check_shape(updates_shape_check, updates_shape) is False:
|
||||
raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
|
||||
f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
|
||||
f"but got the shape of 'update': {updates_shape}, "
|
||||
f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
|
||||
f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")
|
||||
|
||||
return input_x_shape
|
||||
|
||||
|
@ -6328,6 +6327,7 @@ class _TensorScatterOp(PrimitiveWithInfer):
|
|||
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
|
||||
args = {"input_x": input_x_dtype, "updates": updates_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
|
||||
|
||||
return input_x_dtype
|
||||
|
||||
def _check_shape(self, expect, real):
|
||||
|
|
Loading…
Reference in New Issue