!35395 Add functional and tensor interface for `GatherD`

Merge pull request !35395 from TronZhang/gatherd_interface
i-robot 2022-06-10 09:22:38 +00:00 committed by Gitee
commit 0a7e9ac87b
15 changed files with 253 additions and 13 deletions

View File

@ -310,6 +310,7 @@ Array Operations
mindspore.ops.expand_dims
mindspore.ops.gather
mindspore.ops.gather_d
mindspore.ops.gather_elements
mindspore.ops.gather_nd
mindspore.ops.masked_fill
mindspore.ops.masked_select

View File

@ -495,6 +495,39 @@ mindspore.Tensor
Tensor with the same data type as the input tensor.
.. py:method:: gather_elements(dim, index)
Gathers elements along the axis specified by `dim`.
For a 3-D tensor, the output is:
.. code-block::
output[i][j][k] = x[index[i][j][k]][j][k] # if dim == 0
output[i][j][k] = x[i][index[i][j][k]][k] # if dim == 1
output[i][j][k] = x[i][j][index[i][j][k]] # if dim == 2
`index` has the same rank as the current tensor, and all of its dimensions except dimension `dim` match those of the current tensor. If the current tensor is an n-D tensor with shape :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})` and dimension `dim` is i, then `index` must be an n-D tensor with shape :math:`(z_0, z_1, ..., y, ..., z_{n-1})`, where `y` is greater than or equal to 1. The shape of the output is the same as that of `index`.
**Parameters:**
- **dim** (int) - The axis along which to gather elements. The data type is int32 or int64. The value range is [-x_rank, x_rank).
- **index** (Tensor) - The indices of the elements to gather. Supported data types: int32 and int64. The value range of each index element is [-x_rank[dim], x_rank[dim]).
**Returns:**
Tensor, with the same shape as `index`, that is, :math:`(z_0, z_1, ..., y, ..., z_{n-1})`, and with the same data type as `x`.
**Raises:**
- **TypeError** - If the data type of `dim` or `index` is neither int32 nor int64.
- **ValueError** - If the rank of `x` is not equal to the rank of `index`.
- **ValueError** - If the dimensions of `x` and `index`, other than dimension `dim`, are not equal.
- **ValueError** - If the value of `dim` is out of the valid range.
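A hedged sketch of the `dim == 0` rule above, with illustrative values (not taken from the original docs):
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
>>> index = Tensor(np.array([[1, 0], [0, 1]]), mindspore.int32)
>>> print(x.gather_elements(0, index))  # output[i][j] = x[index[i][j]][j]
[[3 2]
 [1 4]]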
.. py:method:: gather_nd(indices)
Gathers slices from the input tensor by indices.

View File

@ -15,19 +15,21 @@ mindspore.ops.gather_d
output[i][j][k] = x[i][j][index[i][j][k]] # if dim == 2
If `x` is an n-D tensor with shape :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})` and dimension `dim` is i, then `index` must be an n-D tensor with shape :math:`(z_0, z_1, ..., y, ..., z_{n-1})`, where `y` is greater than or equal to 1, and the shape of the output is the same as that of `index`.
`x` and `index` have the same rank, and all of their dimensions except dimension `dim` match. If dimension `dim` is i and `x` is an n-D tensor with shape :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})`, then `index` must be an n-D tensor with shape :math:`(z_0, z_1, ..., y, ..., z_{n-1})`, where `y` is greater than or equal to 1, and the shape of the output is the same as that of `index`.
**Parameters:**
- **x** (Tensor) - The input tensor of GatherD, of any rank.
- **dim** (int) - The axis along which to gather elements. The data type is int32 or int64. Only constant values are allowed.
- **dim** (int) - The axis along which to gather elements. The data type is int32 or int64. The value range is [-x_rank, x_rank).
- **index** (Tensor) - The indices of the elements to gather. Supported data types: int32 and int64. The value range of each index element is [-x_rank[dim], x_rank[dim]).
**Returns:**
Tensor, with shape :math:`(z_1, z_2, ..., z_N)` and the same data type as `x`.
Tensor, with the same shape as `index`, that is, :math:`(z_0, z_1, ..., y, ..., z_{n-1})`, and with the same data type as `x`.
**Raises:**
- **TypeError** - If the data type of `dim` or `index` is neither int32 nor int64.
- **ValueError** - If the shape length of `x` is not equal to the shape length of `index`.
- **ValueError** - If the rank of `x` is not equal to the rank of `index`.
- **ValueError** - If the dimensions of `x` and `index`, other than dimension `dim`, are not equal.
- **ValueError** - If the value of `dim` is out of the valid range.
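To make the revised shape rule concrete, a hedged sketch where `index` is shorter than `x` along `dim` (values chosen for illustration):
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mindspore.int32)  # shape (2, 3)
>>> index = Tensor(np.array([[2], [0]]), mindspore.int32)  # shape (2, 1), y = 1
>>> print(ops.gather_d(x, 1, index))  # output[i][j] = x[i][index[i][j]]; output shape == index shape
[[3]
 [4]]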

View File

@ -0,0 +1,7 @@
mindspore.ops.gather_elements
=============================
.. py:function:: mindspore.ops.gather_elements(x, dim, index)
Gathers elements along the axis specified by `dim`.
For more details, refer to :func:`mindspore.ops.gather_d`.
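A minimal usage sketch of the functional interface, mirroring the example added to the functional implementation later in this diff:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
>>> index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
>>> print(ops.gather_elements(x, 1, index))
[[1 1]
 [4 3]]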

View File

@ -179,6 +179,7 @@ BuiltInTypeMap &GetMethodMap() {
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
{"gather_elements", std::string("gather_elements")}, // P.GatherD
{"item", std::string("item")}, // P.item,
{"itemset", std::string("itemset")}, // P.itemset,
{"transpose", std::string("transpose")}, // P.transpose

View File

@ -20,6 +20,44 @@
namespace mindspore {
namespace kernel {
namespace {
std::pair<bool, int64_t> GetDimValue(const ValuePtr &dim_value_ptr) {
MS_EXCEPTION_IF_NULL(dim_value_ptr);
int64_t dim_v = 0;
bool value_type_error = false;
if (dim_value_ptr->isa<tensor::Tensor>()) {
auto dim_tensor = dim_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(dim_tensor);
size_t data_size = dim_tensor->DataSize();
MS_EXCEPTION_IF_CHECK_FAIL(data_size == 1, "the data size of 'dim' must be 1!");
auto dim_type_id = dim_tensor->data_type();
if (dim_type_id == kNumberTypeInt32) {
auto dim_data32 = reinterpret_cast<int *>(dim_tensor->data_c());
MS_EXCEPTION_IF_NULL(dim_data32);
dim_v = static_cast<int64_t>(*dim_data32);
} else if (dim_type_id == kNumberTypeInt64) {
auto dim_data64 = reinterpret_cast<int64_t *>(dim_tensor->data_c());
MS_EXCEPTION_IF_NULL(dim_data64);
dim_v = static_cast<int64_t>(*dim_data64);
} else {
value_type_error = true;
}
} else {
if (dim_value_ptr->isa<Int32Imm>() || dim_value_ptr->isa<Int64Imm>()) {
dim_v = GetValue<int64_t>(dim_value_ptr);
} else {
value_type_error = true;
}
}
if (value_type_error) {
MS_LOG(ERROR) << "For GatherD, 'dim' must be one of these types: [int32/int64].";
return {false, dim_v};
}
return {true, dim_v};
}
} // namespace
std::vector<std::pair<KernelAttr, GatherFwdGpuKernelMod::GatherFwdFunc>> GatherFwdGpuKernelMod::func_list_ = {
// For static shape case:
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
@ -332,7 +370,12 @@ int GatherFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
if (dim_attr == nullptr) {
return KRET_RESIZE_FAILED;
}
dim_value = GetValue<int64_t>(dim_attr);
auto value_res = GetDimValue(dim_attr);
if (!value_res.first) {
return KRET_RESIZE_FAILED;
}
dim_value = value_res.second;
} else {
GetDynamicAttrIntValue(inputs, 1, inputsOnHost, kernel_name_, &dim_value);
}

View File

@ -33,6 +33,7 @@ int64_t GetGatherDimValue(const AbstractBasePtr dim_ptr) {
auto dim_type_ptr = dim_ptr->BuildType();
MS_EXCEPTION_IF_NULL(dim_type_ptr);
int64_t dim_v = 0;
bool value_type_error = false;
if (dim_value_ptr->isa<tensor::Tensor>()) {
auto dim_tensor = dim_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(dim_tensor);
@ -46,19 +47,25 @@ int64_t GetGatherDimValue(const AbstractBasePtr dim_ptr) {
auto dim_data32 = reinterpret_cast<int *>(dim_tensor->data_c());
MS_EXCEPTION_IF_NULL(dim_data32);
dim_v = static_cast<int64_t>(*dim_data32);
} else {
} else if (element->type_id() == kNumberTypeInt64) {
auto dim_data64 = reinterpret_cast<int64_t *>(dim_tensor->data_c());
MS_EXCEPTION_IF_NULL(dim_data64);
dim_v = static_cast<int64_t>(*dim_data64);
} else {
value_type_error = true;
}
} else {
if (dim_value_ptr->isa<Int32Imm>() || dim_value_ptr->isa<Int64Imm>()) {
dim_v = GetValue<int64_t>(dim_value_ptr);
} else {
MS_LOG(EXCEPTION) << "For GatherD, 'dim' must be one of these types: [int32/int64].";
value_type_error = true;
}
}
if (value_type_error) {
MS_LOG(EXCEPTION) << "For GatherD, 'dim' must be one of these types: [int32/int64].";
}
return dim_v;
}
@ -134,7 +141,9 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
(void)CheckAndConvertUtils::CheckTensorTypeValid("index", input_args[kInputIndex2]->BuildType(), index_valid_types,
prim_name);
(void)CheckAndConvertUtils::CheckSubClass("dim", input_args[kInputIndex1]->BuildType(), dim_valid_types, prim_name);
return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
auto infer_type = GatherDInferType(primitive, input_args);
auto infer_shape = GatherDInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
} // namespace ops

View File

@ -1361,6 +1361,18 @@ def one_hot(self, depth, on_value, off_value, axis=-1):
return P.OneHot(axis)(self, depth, on_value, off_value)
def gather_elements(x, dim, index):
r"""
Gathers values along an axis specified by dim.
Refer to :func:`mindspore.ops.gather_d` for more detail.
"""
check_value_type('x', x, (Tensor,), 'Tensor.gather_elements')
check_value_type('dim', dim, (int,), 'Tensor.gather_elements')
check_value_type('index', index, (Tensor,), 'Tensor.gather_elements')
return F.gather_elements(x, dim, index)
def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disable=redefined-builtin
"""
Return sum of array elements over a given axis.

View File

@ -3741,6 +3741,41 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('masked_select')(self, mask)
def gather_elements(self, dim, index):
"""
Gathers values along an axis specified by dim.
Args:
dim (int): The axis along which to index. It must be int32 or int64.
index (Tensor): The indices of elements to gather. It can be one of the following data types:
int32, int64. The value range of each index element is [-x_rank[dim], x_rank[dim]).
Returns:
Tensor, has the same shape as index tensor and same data type as input tensor.
Raises:
TypeError: If dtype of `dim` or `index` is neither int32 nor int64.
ValueError: If length of shape of `x` is not equal to length of shape of `index`.
ValueError: If dimensions of `x` and `index`, except dimension `dim`, are not equal.
ValueError: If the value of `dim` is out of range.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
>>> index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
>>> dim = 1
>>> output = x.gather_elements(dim, index)
>>> print(output)
[[1 1]
[4 3]]
"""
self._init_check()
validator.check_value_type('index', index, (Tensor_,), 'Tensor.gather_elements')
return tensor_operator_registry.get('gather_elements')(self, dim, index)
def nonzero(self):
"""
Return a tensor of the positions of all non-zero values.

View File

@ -73,20 +73,20 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
# If stacked, move vmap_dim to first dim and reshape to right shape.
if x_dim is not None:
if x_dim != 0:
mnp.moveaxis(x, x_dim, 0)
x = mnp.moveaxis(x, x_dim, 0)
x_shape = F.shape(x)
if base_x_len == 2:
x = F.reshape(x, _get_reshape_shape(x_shape, 1))
if lg_dim is not None:
if lg_dim != 0:
mnp.moveaxis(loss_grad, lg_dim, 0)
loss_grad = mnp.moveaxis(loss_grad, lg_dim, 0)
loss_grad_shape = F.shape(loss_grad)
loss_grad = F.reshape(loss_grad, _get_reshape_shape(loss_grad_shape))
if target_dim is not None:
if target_dim != 0:
mnp.moveaxis(target, target_dim, 0)
target = mnp.moveaxis(target, target_dim, 0)
target_shape = F.shape(target)
target = F.reshape(target, _get_reshape_shape(target_shape))
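Context for the fix above: `mnp.moveaxis`, like `numpy.moveaxis`, returns a new tensor rather than modifying its argument, so the previous bare calls were no-ops; the result must be rebound. A minimal NumPy illustration:
>>> import numpy as np
>>> x = np.zeros((2, 3, 4))
>>> np.moveaxis(x, 2, 0).shape  # returns a moved view; x itself is untouched
(4, 2, 3)
>>> x.shape
(2, 3, 4)
>>> x = np.moveaxis(x, 2, 0)  # rebind, as the fixed code now does
>>> x.shape
(4, 2, 3)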

View File

@ -61,6 +61,7 @@ from .array_func import (
scatter_nd_min,
gather,
gather_d,
gather_elements,
gather_nd,
scalar_cast,
masked_fill,

View File

@ -1960,6 +1960,30 @@ def gather_d(x, dim, index):
return gather_d_(x, dim, index)
def gather_elements(x, dim, index):
"""
Gathers values along an axis specified by dim.
Refer to :func:`mindspore.ops.gather_d` for more detail.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
>>> index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
>>> dim = 1
>>> output = mindspore.ops.gather_elements(x, dim, index)
>>> print(output)
[[1 1]
[4 3]]
"""
return gather_d_(x, dim, index)
def gather_nd(input_x, indices):
r"""
Gathers slices from a tensor by indices.
@ -3012,6 +3036,7 @@ __all__ = [
'tensor_scatter_min',
'gather',
'gather_d',
'gather_elements',
'gather_nd',
'one_hot',
'masked_fill',

View File

@ -1001,6 +1001,7 @@ tensor_operator_registry.register('reduce_sum', reduce_sum)
tensor_operator_registry.register('tensor_slice', tensor_slice)
tensor_operator_registry.register('select', select)
tensor_operator_registry.register('gather_d', gather_d)
tensor_operator_registry.register('gather_elements', gather_elements)
tensor_operator_registry.register('gather_nd', gather_nd)
tensor_operator_registry.register('stack', P.Stack)
tensor_operator_registry.register('log', log)

View File

@ -6215,6 +6215,7 @@ class GatherD(Primitive):
def __init__(self):
"""Initialize GatherD"""
self.init_prim_io_names(inputs=['x', 'dim', 'index'], outputs=['output'])
self.set_const_input_indexes([1])
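# Note: set_const_input_indexes([1]) appears to mark the second input ('dim') as
# a constant input so its value is available at compile/infer time. A hedged
# usage sketch of the raw primitive (values mirror the example used in this diff):
#   gather_d = P.GatherD()
#   x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
#   index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
#   gather_d(x, 1, index)  # -> [[1, 1], [4, 3]]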
class Identity(Primitive):

View File

@ -19,9 +19,11 @@ import mindspore
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops as ops
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.functional import vmap
import mindspore.numpy as ms_np
context.set_context(device_target="Ascend")
@ -37,6 +39,15 @@ class Net(nn.Cell):
return self.op(x, self.dim, index)
class TensorNet(nn.Cell):
def __init__(self, dim=0):
super(TensorNet, self).__init__()
self.dim = dim
def construct(self, x, index):
return x.gather_elements(self.dim, index)
class NetGrad(nn.Cell):
def __init__(self, dim=0, shape=None):
super(NetGrad, self).__init__()
@ -107,6 +118,64 @@ def test_gatherd_dynamic(ms_type):
assert np.array_equal(out.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_functional():
"""
Feature: test GatherD functional interface.
Description: inputs x and index have static shapes.
Expectation: the result matches the numpy result.
"""
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
dim = 1
output = ops.gather_elements(x, dim, index)
expect = np.array([[1, 1], [4, 3]])
assert np.array_equal(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_pynative_tensor():
"""
Feature: test GatherD tensor interface in PyNative mode.
Description: inputs x and index have static shapes.
Expectation: the result matches the numpy result.
"""
context.set_context(mode=context.PYNATIVE_MODE)
x = ms_np.array([[1, 2], [3, 4]])
dim = 1
index = ms_np.array([[0, 0], [1, 0]])
output = x.gather_elements(dim, index)
expect = np.array([[1, 1], [4, 3]])
assert np.array_equal(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_graph_tensor():
"""
Feature: test GatherD tensor interface in graph mode.
Description: inputs x and index have static shapes.
Expectation: the result matches the numpy result.
"""
context.set_context(mode=context.GRAPH_MODE)
x = ms_np.array([[1, 2], [3, 4]])
index = ms_np.array([[0, 0], [1, 0]])
dim = 1
net = TensorNet(dim)
output = net(x, index)
expect = np.array([[1, 1], [4, 3]])
assert np.array_equal(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training