mindspore/docs/api/api_python/ops/mindspore.ops.GatherD.rst

33 lines
1.3 KiB
ReStructuredText
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

mindspore.ops.GatherD
=======================
.. py:class:: mindspore.ops.GatherD
沿指定轴收集元素。
对于三维Tensor输出为
.. code-block::
output[i][j][k] = x[index[i][j][k]][j][k] # if dim == 0
output[i][j][k] = x[i][index[i][j][k]][k] # if dim == 1
output[i][j][k] = x[i][j][index[i][j][k]] # if dim == 2
如果 `x` 是shape为 :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})` ,维度为 `dim` = i的n维Tensor`index` 必须是shape为 :math:`(z_0, z_1, ..., y, ..., z_{n-1})` 的n维Tensor其中 `y` 大于等于1输出的shape与 `index` 相同。
**输入:**
- **x** (Tensor) - GatherD的输入任意维度的Tensor。
- **dim** (int) - 指定索引的轴。数据类型为int32或int64。只能是常量值。
- **index** (Tensor) - 指定收集元素的索引。支持的数据类型包括int32int64。每个索引元素的取值范围为[-x_rank[dim], x_rank[dim])。
**输出:**
Tensorshape为 :math:`(z_1, z_2, ..., z_N)` 的Tensor数据类型与 `x` 相同。
**异常:**
- **TypeError** - `dim``index` 的数据类型既不是int32也不是int64。
- **ValueError** - `x` 的shape长度不等于 `index` 的shape长度。