From 52aaa43a5b94fce69fcf83194a5af0b038ae513d Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Thu, 28 Jul 2022 16:08:29 +0800 Subject: [PATCH] fix bug of example in tensor scatter elements docs --- .../ops/mindspore.ops.func_tensor_scatter_elements.rst | 4 +++- mindspore/python/mindspore/ops/function/array_func.py | 5 +++-- mindspore/python/mindspore/ops/operations/array_ops.py | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_elements.rst b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_elements.rst index 1f0626ab9bc..0c1784b0bae 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_elements.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_elements.rst @@ -16,7 +16,9 @@ output[i][j][indices[i][j][k]] = updates[i][j][k] # if axis == 2, reduction == "none" .. warning:: - 如果 `indices` 中有多个索引向量对应于同一位置,则输出中该位置值是不确定的。 + - 如果 `indices` 中有多个索引向量对应于同一位置,则输出中该位置值是不确定的。 + - 在Ascend平台上,目前仅支持 `reduction` 设置为"none"的实现。 + - 在Ascend平台上,`input_x` 仅支持float16和float32两种数据类型。 .. note:: 如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新到 `input_x` ,也不会抛出索引错误。 diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index c4aaf8534c3..3bba616bea1 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -2533,6 +2533,7 @@ def tensor_scatter_elements(input_x, indices, updates, axis=0, reduction="none") in `indices` that correspond to the same position, the value of that position in the output will be nondeterministic. - On Ascend, the reduction only support set to "none" for now. + - On Ascend, the data type of `input_x` must be float16 or float32. .. note:: If some values of the `indices` are out of bound, instead of raising an index error, @@ -2574,9 +2575,9 @@ def tensor_scatter_elements(input_x, indices, updates, axis=0, reduction="none") [[ 2.0 3.0 3.0] [ 5.0 5.0 7.0] [ 7.0 9.0 10.0]] - >>> input_x = Parameter(Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.int32), name="x") + >>> input_x = Parameter(Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.float32), name="x") >>> indices = Tensor(np.array([[2, 4]]), mindspore.int32) - >>> updates = Tensor(np.array([[8, 8]]), mindspore.int32) + >>> updates = Tensor(np.array([[8, 8]]), mindspore.float32) >>> axis = 1 >>> reduction = "none" >>> output = F.tensor_scatter_elements(input_x, indices, updates, axis, reduction) diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 006fdc786c7..e7d938e5305 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -6905,9 +6905,9 @@ class TensorScatterElements(Primitive): [ 0.0 5.0 0.0] [ 7.0 0.0 0.0]] >>> op = ops.TensorScatterElements(1, "add") - >>> data = Tensor(np.array([[1, 2, 3, 4, 5]), mindspore.int32) + >>> data = Tensor(np.array([[1, 2, 3, 4, 5]), mindspore.float32) >>> indices = Tensor(np.array([[2, 4]), mindspore.int32) - >>> updates = Tensor(np.array([[8, 8]]), mindspore.int32) + >>> updates = Tensor(np.array([[8, 8]]), mindspore.float32) >>> output = op(data, indices, updates) >>> print(output) [[ 1 2 11 4 13]]