forked from mindspore-Ecosystem/mindspore
fix bug of example in tensor scatter elements docs
This commit is contained in:
parent
f2e8fdc9ac
commit
52aaa43a5b
|
@ -16,7 +16,9 @@
|
|||
output[i][j][indices[i][j][k]] = updates[i][j][k] # if axis == 2, reduction == "none"
|
||||
|
||||
.. warning::
|
||||
如果 `indices` 中有多个索引向量对应于同一位置,则输出中该位置值是不确定的。
|
||||
- 如果 `indices` 中有多个索引向量对应于同一位置,则输出中该位置值是不确定的。
|
||||
- 在Ascend平台上,目前仅支持 `reduction` 设置为"none"的实现。
|
||||
- 在Ascend平台上,`input_x` 仅支持float16和float32两种数据类型。
|
||||
|
||||
.. note::
|
||||
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新到 `input_x` ,也不会抛出索引错误。
|
||||
|
|
|
@ -2533,6 +2533,7 @@ def tensor_scatter_elements(input_x, indices, updates, axis=0, reduction="none")
|
|||
in `indices` that correspond to the same position, the value of that position in the output will be
|
||||
nondeterministic.
|
||||
- On Ascend, the reduction only support set to "none" for now.
|
||||
- On Ascend, the data type of `input_x` must be float16 or float32.
|
||||
|
||||
.. note::
|
||||
If some values of the `indices` are out of bound, instead of raising an index error,
|
||||
|
@ -2574,9 +2575,9 @@ def tensor_scatter_elements(input_x, indices, updates, axis=0, reduction="none")
|
|||
[[ 2.0 3.0 3.0]
|
||||
[ 5.0 5.0 7.0]
|
||||
[ 7.0 9.0 10.0]]
|
||||
>>> input_x = Parameter(Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.int32), name="x")
|
||||
>>> input_x = Parameter(Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.float32), name="x")
|
||||
>>> indices = Tensor(np.array([[2, 4]]), mindspore.int32)
|
||||
>>> updates = Tensor(np.array([[8, 8]]), mindspore.int32)
|
||||
>>> updates = Tensor(np.array([[8, 8]]), mindspore.float32)
|
||||
>>> axis = 1
|
||||
>>> reduction = "none"
|
||||
>>> output = F.tensor_scatter_elements(input_x, indices, updates, axis, reduction)
|
||||
|
|
|
@ -6905,9 +6905,9 @@ class TensorScatterElements(Primitive):
|
|||
[ 0.0 5.0 0.0]
|
||||
[ 7.0 0.0 0.0]]
|
||||
>>> op = ops.TensorScatterElements(1, "add")
|
||||
>>> data = Tensor(np.array([[1, 2, 3, 4, 5]), mindspore.int32)
|
||||
>>> data = Tensor(np.array([[1, 2, 3, 4, 5]), mindspore.float32)
|
||||
>>> indices = Tensor(np.array([[2, 4]), mindspore.int32)
|
||||
>>> updates = Tensor(np.array([[8, 8]]), mindspore.int32)
|
||||
>>> updates = Tensor(np.array([[8, 8]]), mindspore.float32)
|
||||
>>> output = op(data, indices, updates)
|
||||
>>> print(output)
|
||||
[[ 1 2 11 4 13]]
|
||||
|
|
Loading…
Reference in New Issue