fixed ScatterUpdate

This commit is contained in:
jiangjinsheng 2020-06-09 16:35:30 +08:00
parent ed77c761ec
commit dc548afb93
1 changed files with 14 additions and 9 deletions

View File

@ -2032,7 +2032,7 @@ class ScatterNd(PrimitiveWithInfer):
Creates an empty tensor, and set values by scattering the update tensor depending on indices.
Inputs:
- **indices** (Tensor) - The index of scattering in the new tensor.
- **indices** (Tensor) - The index of scattering in the new tensor. With int32 data type.
- **update** (Tensor) - The source Tensor to be scattered.
- **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices.
@ -2055,7 +2055,7 @@ class ScatterNd(PrimitiveWithInfer):
def __infer__(self, indices, update, shape):
shp = shape['value']
validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
validator.check_value_type("shape", shp, [tuple], self.name)
for i, x in enumerate(shp):
validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name)
@ -2159,7 +2159,7 @@ class ScatterUpdate(PrimitiveWithInfer):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor.
- **indices** (Tensor) - The index of input tensor. With int32 data type.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape + input_x.shape[1:].
@ -2167,9 +2167,11 @@ class ScatterUpdate(PrimitiveWithInfer):
Tensor, has the same shape and type as `input_x`.
Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
>>> update = Tensor(np_update, mindspore.float32)
>>> op = P.ScatterUpdate()
>>> output = op(input_x, indices, update)
"""
@ -2181,6 +2183,7 @@ class ScatterUpdate(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=True):
"""Init ScatterUpdate"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape):
@ -2189,7 +2192,7 @@ class ScatterUpdate(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
@ -2206,14 +2209,15 @@ class ScatterNdUpdate(PrimitiveWithInfer):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor.
- **indices** (Tensor) - The index of input tensor, with int32 data type.
- **update** (Tensor) - The tensor to add to the input tensor, has the same type as input.
Outputs:
Tensor, has the same shape and type as `input_x`.
Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.ScatterNdUpdate()
@ -2227,6 +2231,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=True):
"""Init ScatterNdUpdate"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape):
@ -2237,7 +2242,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype