forked from mindspore-Ecosystem/mindspore
!33758 tensor scatter div dosc
Merge pull request !33758 from ling/code_dosc_sr
This commit is contained in:
commit
02ed0039e3
|
@ -279,6 +279,8 @@ Array操作
|
|||
- Refer to :class:`mindspore.ops.StridedSlice`.
|
||||
* - mindspore.ops.tensor_scatter_add
|
||||
- Refer to :class:`mindspore.ops.TensorScatterAdd`.
|
||||
* - mindspore.ops.tensor_scatter_div
|
||||
- Refer to :class:`mindspore.ops.TensorScatterDiv`.
|
||||
* - mindspore.ops.tensor_scatter_update
|
||||
- Refer to :class:`mindspore.ops.TensorScatterUpdate`.
|
||||
* - mindspore.ops.tensor_slice
|
||||
|
|
|
@ -686,6 +686,28 @@ mindspore.Tensor
|
|||
|
||||
- **ValueError** - `axis` 超出范围,或 `mode` 被设置为'raise'、'wrap'和'clip'以外的值。
|
||||
|
||||
.. py:method:: tensor_scatter_div(indices, updates)
|
||||
|
||||
根据指定的索引, 通过除法进行计算, 将输出赋值到输出Tensor中。
|
||||
|
||||
.. note::
|
||||
- 如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新为 `input_x` ,而不是抛出索引错误。
|
||||
- 算子无法处理除0异常, 用户需保证 `updates `中没有0值。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **indices** (Tensor) - 该Tensor的索引,数据类型为int32或int64的。其rank必须至少为2。
|
||||
- **updates** (Tensor) - 指定与 `input_x` 相加操作的Tensor,其数据类型与输入相同。updates.shape应等于indices.shape[:-1] + input_x.shape[indices.shape[-1]:]。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,shape和数据类型与该Tensor相同。
|
||||
|
||||
**异常:**
|
||||
|
||||
**TypeError** - `indices`的数据类型不是int32,也不是int64。
|
||||
**ValueError** - Tensor的shape长度小于`indices`的shape的最后一个维度。
|
||||
|
||||
.. py:method:: to_tensor(slice_index=None, shape=None, opt_shard_group=None)
|
||||
|
||||
返回init_data()的结果,并获取此Tensor的数据。
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
mindspore.ops.TensorScatterDiv
|
||||
===============================
|
||||
|
||||
.. py:class:: mindspore.ops.TensorScatterDiv
|
||||
|
||||
根据索引,通过相除运算得到输出Tensor的值。此操作与 :class:`mindspore.ops.ScatterNdDiv` 类似,只是更新后的结果是通过算子output返回,而不是直接原地更新input。
|
||||
|
||||
更多参考相见 :func:`mindspore.op.tensor_scatter_div`。
|
|
@ -0,0 +1,28 @@
|
|||
mindspore.ops.TensorScatterDiv
|
||||
===============================
|
||||
|
||||
.. py:class:: mindspore.ops.TensorScatterDiv
|
||||
|
||||
根据索引,通过相除运算得到输出Tensor的值。此操作与 :class:`mindspore.ops.ScatterNdDiv` 类似,只是更新后的结果是通过算子output返回,而不是直接原地更新input。
|
||||
|
||||
`indices` 的最后一个轴是每个索引向量的深度。对于每个索引向量, `updates` 中必须有相应的值。 `updates` 的shape应该等于 `input_x[indices]` 的shape。有关更多详细信息,请参见使用用例。
|
||||
|
||||
.. note::
|
||||
- 如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新为 `input_x` ,而不是抛出索引错误。
|
||||
- 算子无法处理除0异常, 用户需保证 `updates `中没有0值。
|
||||
|
||||
|
||||
**输入:**
|
||||
|
||||
- **input_x** (Tensor) - 输入Tensor。 `input_x` 的维度必须不小于indices.shape[-1]。
|
||||
- **indices** (Tensor) - 输入Tensor的索引,数据类型为int32或int64的。其rank必须至少为2。
|
||||
- **updates** (Tensor) - 指定与 `input_x` 相加操作的Tensor,其数据类型与输入相同。updates.shape应等于indices.shape[:-1] + input_x.shape[indices.shape[-1]:]。
|
||||
|
||||
**输出:**
|
||||
|
||||
Tensor,shape和数据类型与输入 `input_x` 相同。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `indices` 的数据类型既不是int32,也不是int64。
|
||||
- **ValueError** - `input_x` 的shape长度小于 `indices` 的shape的最后一个维度。
|
|
@ -251,6 +251,7 @@ Array Operation
|
|||
mindspore.ops.select
|
||||
mindspore.ops.shape
|
||||
mindspore.ops.size
|
||||
mindspore.ops.tensor_scatter_div
|
||||
mindspore.ops.tile
|
||||
mindspore.ops.transpose
|
||||
mindspore.ops.unique
|
||||
|
|
Loading…
Reference in New Issue