!33758 tensor scatter div dosc

Merge pull request !33758 from ling/code_dosc_sr
This commit is contained in:
i-robot 2022-04-29 04:07:09 +00:00 committed by Gitee
commit 02ed0039e3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 61 additions and 0 deletions

View File

@ -279,6 +279,8 @@ Array操作
- Refer to :class:`mindspore.ops.StridedSlice`.
* - mindspore.ops.tensor_scatter_add
- Refer to :class:`mindspore.ops.TensorScatterAdd`.
* - mindspore.ops.tensor_scatter_div
- Refer to :class:`mindspore.ops.TensorScatterDiv`.
* - mindspore.ops.tensor_scatter_update
- Refer to :class:`mindspore.ops.TensorScatterUpdate`.
* - mindspore.ops.tensor_slice

View File

@ -686,6 +686,28 @@ mindspore.Tensor
- **ValueError** - `axis` 超出范围,或 `mode` 被设置为'raise'、'wrap'和'clip'以外的值。
.. py:method:: tensor_scatter_div(indices, updates)
根据指定的索引, 通过除法进行计算, 将输出赋值到输出Tensor中。
.. note::
- 如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新为 `input_x` ,而不是抛出索引错误。
- 算子无法处理除0异常, 用户需保证 `updates `中没有0值。
**参数:**
- **indices** (Tensor) - 该Tensor的索引数据类型为int32或int64的。其rank必须至少为2。
- **updates** (Tensor) - 指定与 `input_x` 相加操作的Tensor其数据类型与输入相同。updates.shape应等于indices.shape[:-1] + input_x.shape[indices.shape[-1]:]。
**返回:**
Tensorshape和数据类型与该Tensor相同。
**异常:**
**TypeError** - `indices`的数据类型不是int32也不是int64。
**ValueError** - Tensor的shape长度小于`indices`的shape的最后一个维度。
.. py:method:: to_tensor(slice_index=None, shape=None, opt_shard_group=None)
返回init_data()的结果并获取此Tensor的数据。

View File

@ -0,0 +1,8 @@
mindspore.ops.TensorScatterDiv
===============================
.. py:class:: mindspore.ops.TensorScatterDiv
根据索引通过相除运算得到输出Tensor的值。此操作与 :class:`mindspore.ops.ScatterNdDiv` 类似只是更新后的结果是通过算子output返回而不是直接原地更新input。
更多参考相见 :func:`mindspore.op.tensor_scatter_div`

View File

@ -0,0 +1,28 @@
mindspore.ops.TensorScatterDiv
===============================
.. py:class:: mindspore.ops.TensorScatterDiv
根据索引通过相除运算得到输出Tensor的值。此操作与 :class:`mindspore.ops.ScatterNdDiv` 类似只是更新后的结果是通过算子output返回而不是直接原地更新input。
`indices` 的最后一个轴是每个索引向量的深度。对于每个索引向量, `updates` 中必须有相应的值。 `updates` 的shape应该等于 `input_x[indices]` 的shape。有关更多详细信息请参见使用用例。
.. note::
- 如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新为 `input_x` ,而不是抛出索引错误。
- 算子无法处理除0异常, 用户需保证 `updates `中没有0值。
**输入:**
- **input_x** (Tensor) - 输入Tensor。 `input_x` 的维度必须不小于indices.shape[-1]。
- **indices** (Tensor) - 输入Tensor的索引数据类型为int32或int64的。其rank必须至少为2。
- **updates** (Tensor) - 指定与 `input_x` 相加操作的Tensor其数据类型与输入相同。updates.shape应等于indices.shape[:-1] + input_x.shape[indices.shape[-1]:]。
**输出:**
Tensorshape和数据类型与输入 `input_x` 相同。
**异常:**
- **TypeError** - `indices` 的数据类型既不是int32也不是int64。
- **ValueError** - `input_x` 的shape长度小于 `indices` 的shape的最后一个维度。

View File

@ -251,6 +251,7 @@ Array Operation
mindspore.ops.select
mindspore.ops.shape
mindspore.ops.size
mindspore.ops.tensor_scatter_div
mindspore.ops.tile
mindspore.ops.transpose
mindspore.ops.unique