forked from mindspore-Ecosystem/mindspore
fix vmap error for GatherD and NLLLossGrad, and add english docs for gather_elements tensor interface
This commit is contained in:
parent
be8d79b10b
commit
c56dc5bcee
|
@ -669,18 +669,18 @@ mindspore.Tensor
|
|||
|
||||
**参数:**
|
||||
|
||||
- **dim** (int) - 获取元素的轴。数据类型为int32或int64。取值范围为[-x_rank, x_rank)。
|
||||
- **index** (Tensor) - 获取收集元素的索引。支持的数据类型包括:int32,int64。每个索引元素的取值范围为[-x_rank[dim], x_rank[dim])。
|
||||
- **dim** (int) - 获取元素的轴。数据类型为int32或int64。取值范围为[-self.ndim, self.ndim)。
|
||||
- **index** (Tensor) - 获取收集元素的索引。支持的数据类型包括:int32,int64。每个索引元素的取值范围为[-self.shape(dim), self.shape(dim))。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,shape与 `index` 相同,即其shape为 :math:`(z_0, z_1, ..., y, ..., z_{n-1})`,数据类型与 `x` 相同。
|
||||
Tensor,shape与 `index` 相同,即其shape为 :math:`(z_0, z_1, ..., y, ..., z_{n-1})`,数据类型与 `self.dtype` 相同。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `dim` 或 `index` 的数据类型既不是int32,也不是int64。
|
||||
- **ValueError** - `x` 和 `index` 的维度长度不一致。
|
||||
- **ValueError** - `x` 和 `index` 除 `dim` 维外的维度不一致。
|
||||
- **ValueError** - `self` 和 `index` 的维度长度不一致。
|
||||
- **ValueError** - `self` 和 `index` 除 `dim` 维外的维度不一致。
|
||||
- **ValueError** - `dim` 的值不在合理范围内。
|
||||
|
||||
.. py:method:: gather_nd(indices)
|
||||
|
|
|
@ -19,8 +19,8 @@ mindspore.ops.gather_elements
|
|||
|
||||
参数:
|
||||
- **x** (Tensor) - 输入Tensor。
|
||||
- **dim** (int) - 获取元素的轴。数据类型为int32或int64。取值范围为[-x_rank, x_rank)。
|
||||
- **index** (Tensor) - 获取收集元素的索引。支持的数据类型包括:int32,int64。每个索引元素的取值范围为[-x_rank[dim], x_rank[dim])。
|
||||
- **dim** (int) - 获取元素的轴。数据类型为int32或int64。取值范围为[-x.ndim, x.ndim)。
|
||||
- **index** (Tensor) - 获取收集元素的索引。支持的数据类型包括:int32,int64。每个索引元素的取值范围为[-x.shape(dim), x.shape(dim))。
|
||||
|
||||
返回:
|
||||
Tensor,shape与 `index` 相同,即其shape为 :math:`(z_0, z_1, ..., y, ..., z_{n-1})`,数据类型与 `x` 相同。
|
||||
|
|
|
@ -3978,20 +3978,36 @@ class Tensor(Tensor_):
|
|||
def gather_elements(self, dim, index):
|
||||
"""
|
||||
Gathers elements along an axis specified by dim.
|
||||
Refer to :func:`mindspore.ops.gather_elements` for more detail.
|
||||
|
||||
For a 3-D tensor, the output is:
|
||||
|
||||
.. code-block::
|
||||
|
||||
output[i][j][k] = x[index[i][j][k]][j][k] # if dim == 0
|
||||
|
||||
output[i][j][k] = x[i][index[i][j][k]][k] # if dim == 1
|
||||
|
||||
output[i][j][k] = x[i][j][index[i][j][k]] # if dim == 2
|
||||
|
||||
`x` and `index` have the same length of dimensions, and all dimensions except `dim` have the same size.
|
||||
If `dim` = i, `x` is an n-D tensor with shape :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})`,
|
||||
the `index` must be an n-D tensor with shape :math:`(z_0, z_1, ..., y, ..., z_{n-1})`
|
||||
where `y`>=1 and the output will have the same shape with `index`.
|
||||
|
||||
Args:
|
||||
dim (int): The axis along which to index. It must be int32 or int64. The value range is [-x_rank, x_rank).
|
||||
dim (int): The axis along which to index. It must be int32 or int64.
|
||||
The value range is [-self.ndim, self.ndim).
|
||||
index (Tensor): The indices of elements to gather. It can be one of the following data types:
|
||||
int32, int64. The value range of each index element is [-x_rank[dim], x_rank[dim]).
|
||||
int32, int64. The value range of each index element is [-self.shape(dim), self.shape(dim)).
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape as index tensor and same data type as input tensor.
|
||||
Tensor, has the same shape as index tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_{n-1})`,
|
||||
and has the same data type with `self.dtype`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `dim` or `index` is neither int32 nor int64.
|
||||
ValueError: If length of shape of current tensor is not equal to length of shape of `index`.
|
||||
ValueError: If the size of the dimension except `dim` is not equal between current tensor and `index`.
|
||||
ValueError: If length of shape of `self` is not equal to length of shape of `index`.
|
||||
ValueError: If the size of the dimension except `dim` is not equal between `self` and `index`.
|
||||
ValueError: If the value of `dim` is not in the expected range.
|
||||
|
||||
Supported Platforms:
|
||||
|
|
|
@ -1174,7 +1174,7 @@ def get_gatherd_vmap_rule(prim, axis_size):
|
|||
_raise_value_error("The source axis of `dim` in `GatherD` must be None, "
|
||||
"but got {}.".format(axis_dim))
|
||||
if not isinstance(dim_value, int):
|
||||
_raise_value_error("The `dim` in `GatherD` must be a const, but got {}.".format(dim_value))
|
||||
_raise_value_error("The `dim` in `GatherD` must be a int, but got {}.".format(dim_value))
|
||||
|
||||
out_dim = index_dim
|
||||
|
||||
|
@ -1188,7 +1188,8 @@ def get_gatherd_vmap_rule(prim, axis_size):
|
|||
mnp.moveaxis(x, x_dim, index_dim)
|
||||
|
||||
# Adapt `dim` to vmap case.
|
||||
dim_value = dim_value + 1 if dim_value >= out_dim else dim_value
|
||||
x_ndim = ops.rank(x)
|
||||
dim_value = _get_reduce_batch_axis(dim_value, x_dim, x_ndim)
|
||||
|
||||
out = prim(x, dim_value, index)
|
||||
return (out, out_dim)
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore.ops import functional as F
|
|||
from mindspore.ops import constexpr
|
||||
from ..primitive import Primitive
|
||||
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, _bdim_at_front,\
|
||||
_vmap_clone_prim, _bdim_at_any
|
||||
_vmap_clone_prim, _bdim_at_any
|
||||
|
||||
|
||||
@vmap_rules_getters.register(G.NLLLossGrad)
|
||||
|
@ -67,10 +67,10 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
|
|||
if w_dim is not None or tw_dim is not None:
|
||||
_raise_value_error("The source axis of weight and total_weight in `NLLLossGrad` must be None for now, "
|
||||
"but got {} and {}.".format(w_dim, tw_dim))
|
||||
if lg_dim is not None and (base_x_len != 2 or reduction_type != "none"):
|
||||
if lg_dim is not None and reduction_type != "none":
|
||||
_raise_value_error("The source axis of loss_grad in `NLLLossGrad` can be not None "
|
||||
"just when x is 2d and reduction type is none, "
|
||||
"but x is {}d and reduction type is {}.".format(base_x_len, reduction_type))
|
||||
"just when reduction type is none for vmap, "
|
||||
"but reduction type is {}.".format(reduction_type))
|
||||
|
||||
# If stacked, move vmap_dim to first dim and reshape to right shape.
|
||||
if x_dim is not None:
|
||||
|
|
|
@ -2229,12 +2229,13 @@ def gather_elements(x, dim, index):
|
|||
|
||||
Args:
|
||||
x (Tensor): The input tensor.
|
||||
dim (int): The axis along which to index. It must be int32 or int64. The value range is [-x_rank, x_rank).
|
||||
dim (int): The axis along which to index. It must be int32 or int64. The value range is [-x.ndim, x.ndim).
|
||||
index (Tensor): The indices of elements to gather. It can be one of the following data types:
|
||||
int32, int64. The value range of each index element is [-x_rank[dim], x_rank[dim]).
|
||||
int32, int64. The value range of each index element is [-x.shape(dim), x.shape(dim)).
|
||||
|
||||
Returns:
|
||||
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_{n-1})`, has the same data type with `x`.
|
||||
Tensor, has the same shape as index tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_{n-1})`,
|
||||
and has the same data type with `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `dim` or `index` is neither int32 nor int64.
|
||||
|
|
Loading…
Reference in New Issue