fix vmap errors for GatherD and NLLLossGrad, and add English docs for the gather_elements Tensor interface

TronZhang 2022-07-28 10:22:19 +08:00
parent be8d79b10b
commit c56dc5bcee
6 changed files with 40 additions and 22 deletions

View File

@@ -669,18 +669,18 @@ mindspore.Tensor
**Parameters:**
- **dim** (int) - The axis along which to gather the elements. The data type is int32 or int64. The value range is [-x_rank, x_rank).
- **index** (Tensor) - The indices of the elements to gather. The supported data types are int32 and int64. The value range of each index element is [-x_rank[dim], x_rank[dim]).
- **dim** (int) - The axis along which to gather the elements. The data type is int32 or int64. The value range is [-self.ndim, self.ndim).
- **index** (Tensor) - The indices of the elements to gather. The supported data types are int32 and int64. The value range of each index element is [-self.shape(dim), self.shape(dim)).
**Returns:**
Tensor, has the same shape as `index`, that is, its shape is :math:`(z_0, z_1, ..., y, ..., z_{n-1})`, and has the same data type as `x`.
Tensor, has the same shape as `index`, that is, its shape is :math:`(z_0, z_1, ..., y, ..., z_{n-1})`, and its data type is the same as `self.dtype`.
**Raises:**
- **TypeError** - The data type of `dim` or `index` is neither int32 nor int64.
- **ValueError** - `x` and `index` do not have the same number of dimensions.
- **ValueError** - The dimensions of `x` and `index`, other than the `dim` dimension, are not the same.
- **ValueError** - `self` and `index` do not have the same number of dimensions.
- **ValueError** - The dimensions of `self` and `index`, other than the `dim` dimension, are not the same.
- **ValueError** - The value of `dim` is not in the expected range.
.. py:method:: gather_nd(indices)

View File

@@ -19,8 +19,8 @@ mindspore.ops.gather_elements
Parameters:
- **x** (Tensor) - The input tensor.
- **dim** (int) - The axis along which to gather the elements. The data type is int32 or int64. The value range is [-x_rank, x_rank).
- **index** (Tensor) - The indices of the elements to gather. The supported data types are int32 and int64. The value range of each index element is [-x_rank[dim], x_rank[dim]).
- **dim** (int) - The axis along which to gather the elements. The data type is int32 or int64. The value range is [-x.ndim, x.ndim).
- **index** (Tensor) - The indices of the elements to gather. The supported data types are int32 and int64. The value range of each index element is [-x.shape(dim), x.shape(dim)).
Returns:
Tensor, has the same shape as `index`, that is, its shape is :math:`(z_0, z_1, ..., y, ..., z_{n-1})`, and has the same data type as `x`.
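Reviewer's note: the gather semantics documented above can be checked with a small example. The snippet below is a minimal sketch using the public `mindspore.ops.gather_elements` interface; the concrete values are illustrative and are not part of this commit.

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, ops

    # With dim == 1, output[i][j] = x[i][index[i][j]] for every position of `index`.
    x = Tensor(np.array([[1, 2], [3, 4]]), ms.int32)
    index = Tensor(np.array([[0, 0], [1, 0]]), ms.int64)
    out = ops.gather_elements(x, 1, index)
    print(out)
    # Expected output (same shape as `index`, same dtype as `x`):
    # [[1 1]
    #  [4 3]]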

View File

@@ -3978,20 +3978,36 @@ class Tensor(Tensor_):
def gather_elements(self, dim, index):
"""
Gathers elements along an axis specified by dim.
Refer to :func:`mindspore.ops.gather_elements` for more detail.
For a 3-D tensor, the output is:
.. code-block::
output[i][j][k] = x[index[i][j][k]][j][k] # if dim == 0
output[i][j][k] = x[i][index[i][j][k]][k] # if dim == 1
output[i][j][k] = x[i][j][index[i][j][k]] # if dim == 2
`x` and `index` have the same number of dimensions, and all dimensions except `dim` have the same size.
If `dim` = i and `x` is an n-D tensor with shape :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})`,
then `index` must be an n-D tensor with shape :math:`(z_0, z_1, ..., y, ..., z_{n-1})`
where `y` >= 1, and the output will have the same shape as `index`.
Args:
dim (int): The axis along which to index. It must be int32 or int64. The value range is [-x_rank, x_rank).
dim (int): The axis along which to index. It must be int32 or int64.
The value range is [-self.ndim, self.ndim).
index (Tensor): The indices of elements to gather. It can be one of the following data types:
int32, int64. The value range of each index element is [-x_rank[dim], x_rank[dim]).
int32, int64. The value range of each index element is [-self.shape(dim), self.shape(dim)).
Returns:
Tensor, has the same shape as index tensor and same data type as input tensor.
Tensor, has the same shape as `index`, that is, its shape is :math:`(z_0, z_1, ..., y, ..., z_{n-1})`,
and its data type is the same as `self.dtype`.
Raises:
TypeError: If dtype of `dim` or `index` is neither int32 nor int64.
ValueError: If length of shape of current tensor is not equal to length of shape of `index`.
ValueError: If the size of the dimension except `dim` is not equal between current tensor and `index`.
ValueError: If the number of dimensions of `self` is not equal to the number of dimensions of `index`.
ValueError: If the sizes of the dimensions other than `dim` are not equal between `self` and `index`.
ValueError: If the value of `dim` is not in the expected range.
Supported Platforms:

View File

@@ -1174,7 +1174,7 @@ def get_gatherd_vmap_rule(prim, axis_size):
_raise_value_error("The source axis of `dim` in `GatherD` must be None, "
"but got {}.".format(axis_dim))
if not isinstance(dim_value, int):
_raise_value_error("The `dim` in `GatherD` must be a const, but got {}.".format(dim_value))
_raise_value_error("The `dim` in `GatherD` must be a int, but got {}.".format(dim_value))
out_dim = index_dim
@@ -1188,7 +1188,8 @@ def get_gatherd_vmap_rule(prim, axis_size):
mnp.moveaxis(x, x_dim, index_dim)
# Adapt `dim` to vmap case.
dim_value = dim_value + 1 if dim_value >= out_dim else dim_value
x_ndim = ops.rank(x)
dim_value = _get_reduce_batch_axis(dim_value, x_dim, x_ndim)
out = prim(x, dim_value, index)
return (out, out_dim)
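As context for this rule change, the sketch below shows the batched GatherD case it handles: vmapping a fixed-`dim` gather over inputs whose batch dimension sits on axis 0. It assumes the public `ops.vmap` and `ops.gather_elements` interfaces and is not taken from this commit's tests.

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, ops

    def gather_dim0(x, index):
        # `dim` is a Python constant, so its source axis is None inside the vmap rule.
        return ops.gather_elements(x, 0, index)

    x = Tensor(np.arange(24).reshape(3, 2, 4), ms.float32)
    index = Tensor(np.zeros((3, 1, 4)), ms.int64)
    # Batch both inputs along axis 0; each slice performs an independent dim-0 gather.
    out = ops.vmap(gather_dim0, in_axes=(0, 0))(x, index)
    print(out.shape)  # expected: (3, 1, 4)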

View File

@@ -22,7 +22,7 @@ from mindspore.ops import functional as F
from mindspore.ops import constexpr
from ..primitive import Primitive
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, _bdim_at_front,\
_vmap_clone_prim, _bdim_at_any
_vmap_clone_prim, _bdim_at_any
@vmap_rules_getters.register(G.NLLLossGrad)
@@ -67,10 +67,10 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
if w_dim is not None or tw_dim is not None:
_raise_value_error("The source axis of weight and total_weight in `NLLLossGrad` must be None for now, "
"but got {} and {}.".format(w_dim, tw_dim))
if lg_dim is not None and (base_x_len != 2 or reduction_type != "none"):
if lg_dim is not None and reduction_type != "none":
_raise_value_error("The source axis of loss_grad in `NLLLossGrad` can be not None "
"just when x is 2d and reduction type is none, "
"but x is {}d and reduction type is {}.".format(base_x_len, reduction_type))
"just when reduction type is none for vmap, "
"but reduction type is {}.".format(reduction_type))
# If stacked, move vmap_dim to first dim and reshape to right shape.
if x_dim is not None:
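To make the relaxed check above concrete, the sketch below shows the kind of vmapped NLLLoss backward (reduction "none") that goes through this rule: per-sample gradients with a batched loss_grad and an unbatched weight. It assumes `ops.vmap`, `ops.grad`, and the `ops.NLLLoss` primitive behave as in MindSpore 1.8 and is not taken from this commit's tests.

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, ops

    nll = ops.NLLLoss(reduction="none")

    def loss_fn(logits, labels, weight):
        loss, _ = nll(logits, labels, weight)
        return loss

    # Gradient w.r.t. logits; the backward pass runs the NLLLossGrad primitive.
    grad_fn = ops.grad(loss_fn, grad_position=0)

    logits = Tensor(np.random.randn(4, 3, 5).astype(np.float32))
    labels = Tensor(np.random.randint(0, 5, (4, 3)).astype(np.int32))
    weight = Tensor(np.ones(5).astype(np.float32))
    # Batch logits and labels along axis 0; weight stays unbatched (source axis None).
    out = ops.vmap(grad_fn, in_axes=(0, 0, None))(logits, labels, weight)
    print(out.shape)  # expected: (4, 3, 5)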

View File

@@ -2229,12 +2229,13 @@ def gather_elements(x, dim, index):
Args:
x (Tensor): The input tensor.
dim (int): The axis along which to index. It must be int32 or int64. The value range is [-x_rank, x_rank).
dim (int): The axis along which to index. It must be int32 or int64. The value range is [-x.ndim, x.ndim).
index (Tensor): The indices of elements to gather. It can be one of the following data types:
int32, int64. The value range of each index element is [-x_rank[dim], x_rank[dim]).
int32, int64. The value range of each index element is [-x.shape(dim), x.shape(dim)).
Returns:
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_{n-1})`, has the same data type with `x`.
Tensor, has the same shape as `index`, that is, its shape is :math:`(z_0, z_1, ..., y, ..., z_{n-1})`,
and has the same data type as `x`.
Raises:
TypeError: If dtype of `dim` or `index` is neither int32 nor int64.