forked from mindspore-Ecosystem/mindspore
debug
This commit is contained in:
parent
cc66b9879e
commit
6bc07a5b58
|
@ -6304,23 +6304,22 @@ class _TensorScatterOp(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def infer_shape(self, input_x_shape, indices_shape, updates_shape):
|
def infer_shape(self, input_x_shape, indices_shape, updates_shape):
|
||||||
if len(indices_shape) < 2:
|
if indices_shape != [-2] and len(indices_shape) < 2:
|
||||||
raise ValueError(f"For '{self.name}', the dimension of 'indices' cannot be less than 2,"
|
raise ValueError(f"For '{self.name}', the dimension of 'indices' cannot be less than 2,"
|
||||||
f" but got {len(indices_shape)}.")
|
f" but got {len(indices_shape)}.")
|
||||||
|
if indices_shape[-1] > 0:
|
||||||
if indices_shape[-1] > len(input_x_shape):
|
if indices_shape[-1] > len(input_x_shape):
|
||||||
raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
|
raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
|
||||||
f"the dimension of 'input_x', but got the "
|
f"the dimension of 'input_x', but got the "
|
||||||
f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
|
f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
|
||||||
f"{len(input_x_shape)}.")
|
f"{len(input_x_shape)}.")
|
||||||
|
updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
|
||||||
updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
|
if self._check_shape(updates_shape_check, updates_shape) is False:
|
||||||
if self._check_shape(updates_shape_check, updates_shape) is False:
|
raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
|
||||||
raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
|
f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
|
||||||
f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
|
f"but got the shape of 'update': {updates_shape}, "
|
||||||
f"but got the shape of 'update': {updates_shape}, "
|
f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
|
||||||
f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
|
f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")
|
||||||
f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")
|
|
||||||
|
|
||||||
return input_x_shape
|
return input_x_shape
|
||||||
|
|
||||||
|
@ -6328,6 +6327,7 @@ class _TensorScatterOp(PrimitiveWithInfer):
|
||||||
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
|
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
|
||||||
args = {"input_x": input_x_dtype, "updates": updates_dtype}
|
args = {"input_x": input_x_dtype, "updates": updates_dtype}
|
||||||
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
|
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
|
||||||
|
|
||||||
return input_x_dtype
|
return input_x_dtype
|
||||||
|
|
||||||
def _check_shape(self, expect, real):
|
def _check_shape(self, expect, real):
|
||||||
|
|
Loading…
Reference in New Issue