forked from mindspore-Ecosystem/mindspore
fixed ScatterUpdate
This commit is contained in:
parent
ed77c761ec
commit
dc548afb93
|
@ -2032,7 +2032,7 @@ class ScatterNd(PrimitiveWithInfer):
|
|||
Creates an empty tensor, and set values by scattering the update tensor depending on indices.
|
||||
|
||||
Inputs:
|
||||
- **indices** (Tensor) - The index of scattering in the new tensor.
|
||||
- **indices** (Tensor) - The index of scattering in the new tensor. With int32 data type.
|
||||
- **update** (Tensor) - The source Tensor to be scattered.
|
||||
- **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices.
|
||||
|
||||
|
@ -2055,7 +2055,7 @@ class ScatterNd(PrimitiveWithInfer):
|
|||
def __infer__(self, indices, update, shape):
|
||||
shp = shape['value']
|
||||
validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
|
||||
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
|
||||
validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
|
||||
validator.check_value_type("shape", shp, [tuple], self.name)
|
||||
for i, x in enumerate(shp):
|
||||
validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name)
|
||||
|
@ -2159,7 +2159,7 @@ class ScatterUpdate(PrimitiveWithInfer):
|
|||
|
||||
Inputs:
|
||||
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
|
||||
- **indices** (Tensor) - The index of input tensor.
|
||||
- **indices** (Tensor) - The index of input tensor. With int32 data type.
|
||||
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
||||
and update.shape = indices.shape + input_x.shape[1:].
|
||||
|
||||
|
@ -2167,9 +2167,11 @@ class ScatterUpdate(PrimitiveWithInfer):
|
|||
Tensor, has the same shape and type as `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
|
||||
>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
|
||||
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
|
||||
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
|
||||
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
|
||||
>>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
|
||||
>>> update = Tensor(np_update, mindspore.float32)
|
||||
>>> op = P.ScatterUpdate()
|
||||
>>> output = op(input_x, indices, update)
|
||||
"""
|
||||
|
@ -2181,6 +2183,7 @@ class ScatterUpdate(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, use_locking=True):
|
||||
"""Init ScatterUpdate"""
|
||||
validator.check_value_type('use_locking', use_locking, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||
|
@ -2189,7 +2192,7 @@ class ScatterUpdate(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
||||
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
|
||||
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
|
||||
args = {"x": x_dtype, "value": value_dtype}
|
||||
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
return x_dtype
|
||||
|
@ -2206,14 +2209,15 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
|||
|
||||
Inputs:
|
||||
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
|
||||
- **indices** (Tensor) - The index of input tensor.
|
||||
- **indices** (Tensor) - The index of input tensor, with int32 data type.
|
||||
- **update** (Tensor) - The tensor to add to the input tensor, has the same type as input.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and type as `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
|
||||
>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
|
||||
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
|
||||
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
|
||||
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
|
||||
>>> op = P.ScatterNdUpdate()
|
||||
|
@ -2227,6 +2231,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, use_locking=True):
|
||||
"""Init ScatterNdUpdate"""
|
||||
validator.check_value_type('use_locking', use_locking, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||
|
@ -2237,7 +2242,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
||||
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
|
||||
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
|
||||
args = {"x": x_dtype, "value": value_dtype}
|
||||
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
return x_dtype
|
||||
|
|
Loading…
Reference in New Issue