From dc548afb93e8a14ff08c8fcb7a6f1be7201c606d Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Tue, 9 Jun 2020 16:35:30 +0800 Subject: [PATCH] fixed ScatterUpdate --- mindspore/ops/operations/array_ops.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 4042be84cc2..3df558ace6a 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2032,7 +2032,7 @@ class ScatterNd(PrimitiveWithInfer): Creates an empty tensor, and set values by scattering the update tensor depending on indices. Inputs: - - **indices** (Tensor) - The index of scattering in the new tensor. + - **indices** (Tensor) - The index of scattering in the new tensor. With int32 data type. - **update** (Tensor) - The source Tensor to be scattered. - **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices. @@ -2055,7 +2055,7 @@ class ScatterNd(PrimitiveWithInfer): def __infer__(self, indices, update, shape): shp = shape['value'] validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name) - validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) + validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name) validator.check_value_type("shape", shp, [tuple], self.name) for i, x in enumerate(shp): validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name) @@ -2159,7 +2159,7 @@ class ScatterUpdate(PrimitiveWithInfer): Inputs: - **input_x** (Parameter) - The target tensor, with data type of Parameter. - - **indices** (Tensor) - The index of input tensor. + - **indices** (Tensor) - The index of input tensor. With int32 data type. - **update** (Tensor) - The tensor to update the input tensor, has the same type as input, and update.shape = indices.shape + input_x.shape[1:]. @@ -2167,9 +2167,11 @@ class ScatterUpdate(PrimitiveWithInfer): Tensor, has the same shape and type as `input_x`. Examples: - >>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)) + >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]) + >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x") >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) - >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) + >>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]]) + >>> update = Tensor(np_update, mindspore.float32) >>> op = P.ScatterUpdate() >>> output = op(input_x, indices, update) """ @@ -2181,6 +2183,7 @@ class ScatterUpdate(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=True): """Init ScatterUpdate""" + validator.check_value_type('use_locking', use_locking, [bool], self.name) self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) def infer_shape(self, x_shape, indices_shape, value_shape): @@ -2189,7 +2192,7 @@ class ScatterUpdate(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype, indices_dtype, value_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) + validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) args = {"x": x_dtype, "value": value_dtype} validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) return x_dtype @@ -2206,14 +2209,15 @@ class ScatterNdUpdate(PrimitiveWithInfer): Inputs: - **input_x** (Parameter) - The target tensor, with data type of Parameter. - - **indices** (Tensor) - The index of input tensor. + - **indices** (Tensor) - The index of input tensor, with int32 data type. - **update** (Tensor) - The tensor to add to the input tensor, has the same type as input. Outputs: Tensor, has the same shape and type as `input_x`. Examples: - >>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)) + >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]) + >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x") >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) >>> op = P.ScatterNdUpdate() @@ -2227,6 +2231,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=True): """Init ScatterNdUpdate""" + validator.check_value_type('use_locking', use_locking, [bool], self.name) self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) def infer_shape(self, x_shape, indices_shape, value_shape): @@ -2237,7 +2242,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype, indices_dtype, value_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) + validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) args = {"x": x_dtype, "value": value_dtype} validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) return x_dtype