fix bug of example in tensor scatter elements docs

This commit is contained in:
mengyuanli 2022-07-28 16:08:29 +08:00
parent f2e8fdc9ac
commit 52aaa43a5b
3 changed files with 8 additions and 5 deletions

View File

@ -16,7 +16,9 @@
output[i][j][indices[i][j][k]] = updates[i][j][k] # if axis == 2, reduction == "none"
.. warning::
如果 `indices` 中有多个索引向量对应于同一位置,则输出中该位置值是不确定的。
- 如果 `indices` 中有多个索引向量对应于同一位置,则输出中该位置值是不确定的。
- 在Ascend平台上目前仅支持 `reduction` 设置为"none"的实现。
- 在Ascend平台上`input_x` 仅支持float16和float32两种数据类型。
.. note::
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新到 `input_x` ,也不会抛出索引错误。

View File

@ -2533,6 +2533,7 @@ def tensor_scatter_elements(input_x, indices, updates, axis=0, reduction="none")
in `indices` that correspond to the same position, the value of that position in the output will be
nondeterministic.
- On Ascend, the reduction only support set to "none" for now.
- On Ascend, the data type of `input_x` must be float16 or float32.
.. note::
If some values of the `indices` are out of bound, instead of raising an index error,
@ -2574,9 +2575,9 @@ def tensor_scatter_elements(input_x, indices, updates, axis=0, reduction="none")
[[ 2.0 3.0 3.0]
[ 5.0 5.0 7.0]
[ 7.0 9.0 10.0]]
>>> input_x = Parameter(Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.int32), name="x")
>>> input_x = Parameter(Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[2, 4]]), mindspore.int32)
>>> updates = Tensor(np.array([[8, 8]]), mindspore.int32)
>>> updates = Tensor(np.array([[8, 8]]), mindspore.float32)
>>> axis = 1
>>> reduction = "none"
>>> output = F.tensor_scatter_elements(input_x, indices, updates, axis, reduction)

View File

@ -6905,9 +6905,9 @@ class TensorScatterElements(Primitive):
[ 0.0 5.0 0.0]
[ 7.0 0.0 0.0]]
>>> op = ops.TensorScatterElements(1, "add")
>>> data = Tensor(np.array([[1, 2, 3, 4, 5]), mindspore.int32)
>>> data = Tensor(np.array([[1, 2, 3, 4, 5]), mindspore.float32)
>>> indices = Tensor(np.array([[2, 4]), mindspore.int32)
>>> updates = Tensor(np.array([[8, 8]]), mindspore.int32)
>>> updates = Tensor(np.array([[8, 8]]), mindspore.float32)
>>> output = op(data, indices, updates)
>>> print(output)
[[ 1 2 11 4 13]]