fix gather_nd tensor rst

This commit is contained in:
zengzitao 2022-05-23 18:56:09 +08:00
parent a3e00bd637
commit b75909d878
2 changed files with 18 additions and 16 deletions

View File

@ -385,30 +385,31 @@ mindspore.Tensor
与输入的张量具有相同的数据类型的Tensor。
.. py:method:: gather_nd(input_x, indices)
.. py:method:: gather_nd(indices)
按索引从张量中获取切片。
使用给定的索引从具有指定形状的张量中搜集切片。
`indices` 是一个K维的整数张量假定它的K-1维张量中的每一个元素是 `input_x` 的切片,那么有:
按索引从输入Tensor中获取切片。
使用给定的索引从具有指定形状的输入Tensor中搜集切片。
输入Tensor的shape是 :math:`(N,*)` ,其中 :math:`*` 表示任意数量的附加维度。为了表达方便,
将其定义为`input_x`
`indices` 是一个K维的整数张量假定它的K-1维张量中的每一个元素是输入Tensor的切片那么有
.. math::
output[(i_0, ..., i_{K-2})] = input\_x[indices[(i_0, ..., i_{K-2})]]
`indices` 的最后一维不能超过 `input_x` 的秩:
`indices` 的最后一维不能超过输入Tensor的秩:
:math:`indices.shape[-1] <= input\_x.rank`
**参数:**
- **input_x** (Tensor) - 待搜集元素的目标张量它的shape是 :math:`(N,*)` ,其中 :math:`*` 表示任意数量的附加维度。
- **indices** (Tensor) - 获取收集元素的索引张量其数据类型包括int32int64。
**返回:**
Tensor具有与入参 `input_x` 相同的数据类型shape维度为indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
Tensor具有与输入Tensor相同的数据类型shape维度为 :math:`indices\_shape[:-1] + input\_x\_shape[indices\_shape[-1]:]`
**异常:**
- **ValueError** - 如果 `input_x` 的shape长度小于 `indices` 的最后一个维度。
- **ValueError** - 如果输入Tensor的shape长度小于 `indices` 的最后一个维度。
.. py:method:: ger(x)

View File

@ -2652,27 +2652,28 @@ class Tensor(Tensor_):
def gather_nd(self, indices):
r"""
Gathers slices from a tensor by indices.
Using given indices to gather slices from a tensor with a specified shape.
`input_x` is a target tensor of the element to be collected,
The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
Gathers slices from a input tensor by indices.
Using given indices to gather slices from a input tensor with a specified shape.
input tensor's shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions. For convenience
define it as `input_x`.
`indices` is an K-dimensional integer tensor. Suppose that it is a (K-1)-dimensional tensor and each element
of it defines a slice of `input_x`:
of it defines a slice of input tensor:
.. math::
output[(i_0, ..., i_{K-2})] = input\_x[indices[(i_0, ..., i_{K-2})]]
The last dimension of `indices` can not more than the rank of `input_x`:
The last dimension of `indices` can not more than the rank of input tensor:
:math:`indices.shape[-1] <= input\_x.rank`.
Args:
indices (Tensor): The index tensor that gets the collected elements, with int32 or int64 data type.
Returns:
Tensor, has the same type as `input_x` and the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:].
Tensor, has the same type as input tensor and the shape is:
:math:`indices\_shape[:-1] + input\_x\_shape[indices\_shape[-1]:]`.
Raises:
ValueError: If length of shape of `input_x` is less than the last dimension of `indices`.
ValueError: If length of shape of input tensor is less than the last dimension of `indices`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``