!1930 fix validator for ScatterNdUpdate

Merge pull request !1930 from jiangjinsheng/issue_doc
This commit is contained in:
mindspore-ci-bot 2020-06-11 11:04:34 +08:00 committed by Gitee
commit cc0add562b
1 changed files with 14 additions and 9 deletions

View File

@ -2032,7 +2032,7 @@ class ScatterNd(PrimitiveWithInfer):
Creates an empty tensor, and set values by scattering the update tensor depending on indices. Creates an empty tensor, and set values by scattering the update tensor depending on indices.
Inputs: Inputs:
- **indices** (Tensor) - The index of scattering in the new tensor. - **indices** (Tensor) - The index of scattering in the new tensor. With int32 data type.
- **update** (Tensor) - The source Tensor to be scattered. - **update** (Tensor) - The source Tensor to be scattered.
- **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices. - **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices.
@ -2055,7 +2055,7 @@ class ScatterNd(PrimitiveWithInfer):
def __infer__(self, indices, update, shape): def __infer__(self, indices, update, shape):
shp = shape['value'] shp = shape['value']
validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name) validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
validator.check_value_type("shape", shp, [tuple], self.name) validator.check_value_type("shape", shp, [tuple], self.name)
for i, x in enumerate(shp): for i, x in enumerate(shp):
validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name) validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name)
@ -2159,7 +2159,7 @@ class ScatterUpdate(PrimitiveWithInfer):
Inputs: Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter. - **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor. - **indices** (Tensor) - The index of input tensor. With int32 data type.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input, - **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape + input_x.shape[1:]. and update.shape = indices.shape + input_x.shape[1:].
@ -2167,9 +2167,11 @@ class ScatterUpdate(PrimitiveWithInfer):
Tensor, has the same shape and type as `input_x`. Tensor, has the same shape and type as `input_x`.
Examples: Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)) >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) >>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
>>> update = Tensor(np_update, mindspore.float32)
>>> op = P.ScatterUpdate() >>> op = P.ScatterUpdate()
>>> output = op(input_x, indices, update) >>> output = op(input_x, indices, update)
""" """
@ -2181,6 +2183,7 @@ class ScatterUpdate(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, use_locking=True): def __init__(self, use_locking=True):
"""Init ScatterUpdate""" """Init ScatterUpdate"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape): def infer_shape(self, x_shape, indices_shape, value_shape):
@ -2189,7 +2192,7 @@ class ScatterUpdate(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype} args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype return x_dtype
@ -2206,14 +2209,15 @@ class ScatterNdUpdate(PrimitiveWithInfer):
Inputs: Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter. - **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor. - **indices** (Tensor) - The index of input tensor, with int32 data type.
- **update** (Tensor) - The tensor to add to the input tensor, has the same type as input. - **update** (Tensor) - The tensor to add to the input tensor, has the same type as input.
Outputs: Outputs:
Tensor, has the same shape and type as `input_x`. Tensor, has the same shape and type as `input_x`.
Examples: Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)) >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.ScatterNdUpdate() >>> op = P.ScatterNdUpdate()
@ -2227,6 +2231,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, use_locking=True): def __init__(self, use_locking=True):
"""Init ScatterNdUpdate""" """Init ScatterNdUpdate"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape): def infer_shape(self, x_shape, indices_shape, value_shape):
@ -2237,7 +2242,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype} args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype return x_dtype