From b75909d878cfd7d96cda8cf359085308347a5915 Mon Sep 17 00:00:00 2001 From: zengzitao Date: Mon, 23 May 2022 18:56:09 +0800 Subject: [PATCH] fix gather_nd tensor rst --- .../api_python/mindspore/mindspore.Tensor.rst | 17 +++++++++-------- mindspore/python/mindspore/common/tensor.py | 17 +++++++++-------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/docs/api/api_python/mindspore/mindspore.Tensor.rst b/docs/api/api_python/mindspore/mindspore.Tensor.rst index e5dd1a131e0..1f8ca1231dc 100644 --- a/docs/api/api_python/mindspore/mindspore.Tensor.rst +++ b/docs/api/api_python/mindspore/mindspore.Tensor.rst @@ -385,30 +385,31 @@ mindspore.Tensor 与输入的张量具有相同的数据类型的Tensor。 - .. py:method:: gather_nd(input_x, indices) + .. py:method:: gather_nd(indices) - 按索引从张量中获取切片。 - 使用给定的索引从具有指定形状的张量中搜集切片。 - `indices` 是一个K维的整数张量,假定它的K-1维张量中的每一个元素是 `input_x` 的切片,那么有: + 按索引从输入Tensor中获取切片。 + 使用给定的索引从具有指定形状的输入Tensor中搜集切片。 + 输入Tensor的shape是 :math:`(N,*)` ,其中 :math:`*` 表示任意数量的附加维度。为了表达方便, + 将其定义为`input_x`。 + `indices` 是一个K维的整数张量,假定它的K-1维张量中的每一个元素是输入Tensor的切片,那么有: .. math:: output[(i_0, ..., i_{K-2})] = input\_x[indices[(i_0, ..., i_{K-2})]] - `indices` 的最后一维不能超过 `input_x` 的秩: + `indices` 的最后一维不能超过输入Tensor的秩: :math:`indices.shape[-1] <= input\_x.rank`。 **参数:** - - **input_x** (Tensor) - 待搜集元素的目标张量,它的shape是 :math:`(N,*)` ,其中 :math:`*` 表示任意数量的附加维度。 - **indices** (Tensor) - 获取收集元素的索引张量,其数据类型包括:int32,int64。 **返回:** - Tensor,具有与入参 `input_x` 相同的数据类型,shape维度为indices_shape[:-1] + input_x_shape[indices_shape[-1]:]。 + Tensor,具有与输入Tensor相同的数据类型,shape维度为 :math:`indices\_shape[:-1] + input\_x\_shape[indices\_shape[-1]:]`。 **异常:** - - **ValueError** - 如果 `input_x` 的shape长度小于 `indices` 的最后一个维度。 + - **ValueError** - 如果输入Tensor的shape长度小于 `indices` 的最后一个维度。 .. py:method:: ger(x) diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index f05c8c773f6..a787e0ba9ed 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -2652,27 +2652,28 @@ class Tensor(Tensor_): def gather_nd(self, indices): r""" - Gathers slices from a tensor by indices. - Using given indices to gather slices from a tensor with a specified shape. - `input_x` is a target tensor of the element to be collected, - The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions. + Gathers slices from a input tensor by indices. + Using given indices to gather slices from a input tensor with a specified shape. + input tensor's shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions. For convenience + define it as `input_x`. `indices` is an K-dimensional integer tensor. Suppose that it is a (K-1)-dimensional tensor and each element - of it defines a slice of `input_x`: + of it defines a slice of input tensor: .. math:: output[(i_0, ..., i_{K-2})] = input\_x[indices[(i_0, ..., i_{K-2})]] - The last dimension of `indices` can not more than the rank of `input_x`: + The last dimension of `indices` can not more than the rank of input tensor: :math:`indices.shape[-1] <= input\_x.rank`. Args: indices (Tensor): The index tensor that gets the collected elements, with int32 or int64 data type. Returns: - Tensor, has the same type as `input_x` and the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:]. + Tensor, has the same type as input tensor and the shape is: + :math:`indices\_shape[:-1] + input\_x\_shape[indices\_shape[-1]:]`. Raises: - ValueError: If length of shape of `input_x` is less than the last dimension of `indices`. + ValueError: If length of shape of input tensor is less than the last dimension of `indices`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU``