forked from mindspore-Ecosystem/mindspore
!35395 Add functional and tensor interface for `GatherD`
Merge pull request !35395 from TronZhang/gatherd_interface
This commit is contained in:
commit
0a7e9ac87b
|
@ -310,6 +310,7 @@ Array操作
|
|||
mindspore.ops.expand_dims
|
||||
mindspore.ops.gather
|
||||
mindspore.ops.gather_d
|
||||
mindspore.ops.gather_elements
|
||||
mindspore.ops.gather_nd
|
||||
mindspore.ops.masked_fill
|
||||
mindspore.ops.masked_select
|
||||
|
|
|
@ -495,6 +495,39 @@ mindspore.Tensor
|
|||
|
||||
与输入的张量具有相同的数据类型的Tensor。
|
||||
|
||||
.. py:method:: gather_elements(dim, index)
|
||||
|
||||
获取指定轴的元素。
|
||||
|
||||
对于三维Tensor,输出为:
|
||||
|
||||
.. code-block::
|
||||
|
||||
output[i][j][k] = x[index[i][j][k]][j][k] # if dim == 0
|
||||
|
||||
output[i][j][k] = x[i][index[i][j][k]][k] # if dim == 1
|
||||
|
||||
output[i][j][k] = x[i][j][index[i][j][k]] # if dim == 2
|
||||
|
||||
`index` 与当前Tensor拥有一样的维度长度,且除 `dim` 维外其他维度一致。如果维度 `dim` 为i,当前Tensor是shape为 :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})` 的n维Tensor,则 `index` 必须是shape为 :math:`(z_0, z_1, ..., y, ..., z_{n-1})` 的n维Tensor,其中 `y` 大于等于1。输出的shape与 `index` 相同。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **dim** (int) - 获取元素的轴。数据类型为int32或int64。取值范围为[-x_rank, x_rank)。
|
||||
- **index** (Tensor) - 获取收集元素的索引。支持的数据类型包括:int32,int64。每个索引元素的取值范围为[-x_rank[dim], x_rank[dim])。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,shape与 `index` 相同,即其shape为 :math:`(z_0, z_1, ..., y, ..., z_{n-1})`,数据类型与 `x` 相同。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `dim` 或 `index` 的数据类型既不是int32,也不是int64。
|
||||
- **ValueError** - `x` 和 `index` 的维度长度不一致。
|
||||
- **ValueError** - `x` 和 `index` 除 `dim` 维外的维度不一致。
|
||||
- **ValueError** - `dim` 的值不在合理范围内。
|
||||
|
||||
|
||||
.. py:method:: gather_nd(indices)
|
||||
|
||||
按索引从输入Tensor中获取切片。
|
||||
|
|
|
@ -15,19 +15,21 @@ mindspore.ops.gather_d
|
|||
|
||||
output[i][j][k] = x[i][j][index[i][j][k]] # if dim == 2
|
||||
|
||||
如果 `x` 是shape为 :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})` ,维度 `dim` 为i的n维Tensor,则 `index` 必须是shape为 :math:`(z_0, z_1, ..., y, ..., z_{n-1})` 的n维Tensor,其中 `y` 大于等于1,输出的shape与 `index` 相同。
|
||||
`x` 与 `index` 拥有一样的维度长度,且除 `dim` 维外其他维度一致。如果维度 `dim` 为i, `x` 是shape为 :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})` 的n维Tensor,则 `index` 必须是shape为 :math:`(z_0, z_1, ..., y, ..., z_{n-1})` 的n维Tensor,其中 `y` 大于等于1。输出的shape与 `index` 相同。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **x** (Tensor) - GatherD的输入,任意维度的Tensor。
|
||||
- **dim** (int) - 获取元素的轴。数据类型为int32或int64。只能是常量值。
|
||||
- **dim** (int) - 获取元素的轴。数据类型为int32或int64。取值范围为[-x_rank, x_rank)。
|
||||
- **index** (Tensor) - 获取收集元素的索引。支持的数据类型包括:int32,int64。每个索引元素的取值范围为[-x_rank[dim], x_rank[dim])。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,shape为 :math:`(z_1, z_2, ..., z_N)` 的Tensor,数据类型与 `x` 相同。
|
||||
Tensor,shape与 `index` 相同,即其shape为 :math:`(z_0, z_1, ..., y, ..., z_{n-1})`,数据类型与 `x` 相同。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `dim` 或 `index` 的数据类型既不是int32,也不是int64。
|
||||
- **ValueError** - `x` 的shape长度不等于 `index` 的shape长度。
|
||||
- **ValueError** - `x` 和 `index` 的维度长度不一致。
|
||||
- **ValueError** - `x` 和 `index` 除 `dim` 维外的维度不一致。
|
||||
- **ValueError** - `dim` 的值不在合理范围内。
|
|
@ -0,0 +1,7 @@
|
|||
mindspore.ops.gather_elements
|
||||
=============================
|
||||
|
||||
.. py:function:: mindspore.ops.gather_elements(x, dim, index)
|
||||
|
||||
获取指定轴的元素。
|
||||
有关更多详细信息,请参阅: py:function:: `mindspore.ops.gather_d`。
|
|
@ -179,6 +179,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
|
||||
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
|
||||
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
|
||||
{"gather_elements", std::string("gather_elements")}, // P.GatherD
|
||||
{"item", std::string("item")}, // P.item,
|
||||
{"itemset", std::string("itemset")}, // P.itemset,
|
||||
{"transpose", std::string("transpose")}, // P.transpose
|
||||
|
|
|
@ -20,6 +20,44 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
std::pair<bool, int64_t> GetDimValue(const ValuePtr &dim_value_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(dim_value_ptr);
|
||||
int64_t dim_v = 0;
|
||||
bool value_type_error = false;
|
||||
if (dim_value_ptr->isa<tensor::Tensor>()) {
|
||||
auto dim_tensor = dim_value_ptr->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(dim_tensor);
|
||||
size_t data_size = dim_tensor->DataSize();
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(data_size == 1, "dim value is not equal to one!");
|
||||
auto dim_type_id = dim_tensor->data_type();
|
||||
if (dim_type_id == kNumberTypeInt32) {
|
||||
auto dim_data32 = reinterpret_cast<int *>(dim_tensor->data_c());
|
||||
MS_EXCEPTION_IF_NULL(dim_data32);
|
||||
dim_v = static_cast<int64_t>(*dim_data32);
|
||||
} else if (dim_type_id == kNumberTypeInt64) {
|
||||
auto dim_data64 = reinterpret_cast<int64_t *>(dim_tensor->data_c());
|
||||
MS_EXCEPTION_IF_NULL(dim_data64);
|
||||
dim_v = static_cast<int64_t>(*dim_data64);
|
||||
} else {
|
||||
value_type_error = true;
|
||||
}
|
||||
} else {
|
||||
if (dim_value_ptr->isa<Int32Imm>() || dim_value_ptr->isa<Int64Imm>()) {
|
||||
dim_v = GetValue<int64_t>(dim_value_ptr);
|
||||
} else {
|
||||
value_type_error = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (value_type_error) {
|
||||
MS_LOG(ERROR) << "For GatherD, 'dim' must be one of these types: [int32/int64].";
|
||||
return {false, dim_v};
|
||||
}
|
||||
return {true, dim_v};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<std::pair<KernelAttr, GatherFwdGpuKernelMod::GatherFwdFunc>> GatherFwdGpuKernelMod::func_list_ = {
|
||||
// For static shape case:
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
|
@ -332,7 +370,12 @@ int GatherFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
|
|||
if (dim_attr == nullptr) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
dim_value = GetValue<int64_t>(dim_attr);
|
||||
|
||||
auto value_res = GetDimValue(dim_attr);
|
||||
if (!value_res.first) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
dim_value = value_res.second;
|
||||
} else {
|
||||
GetDynamicAttrIntValue(inputs, 1, inputsOnHost, kernel_name_, &dim_value);
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ int64_t GetGatherDimValue(const AbstractBasePtr dim_ptr) {
|
|||
auto dim_type_ptr = dim_ptr->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(dim_type_ptr);
|
||||
int64_t dim_v = 0;
|
||||
bool value_type_error = false;
|
||||
if (dim_value_ptr->isa<tensor::Tensor>()) {
|
||||
auto dim_tensor = dim_value_ptr->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(dim_tensor);
|
||||
|
@ -46,19 +47,25 @@ int64_t GetGatherDimValue(const AbstractBasePtr dim_ptr) {
|
|||
auto dim_data32 = reinterpret_cast<int *>(dim_tensor->data_c());
|
||||
MS_EXCEPTION_IF_NULL(dim_data32);
|
||||
dim_v = static_cast<int64_t>(*dim_data32);
|
||||
} else {
|
||||
} else if (element->type_id() == kNumberTypeInt64) {
|
||||
auto dim_data64 = reinterpret_cast<int64_t *>(dim_tensor->data_c());
|
||||
MS_EXCEPTION_IF_NULL(dim_data64);
|
||||
dim_v = static_cast<int64_t>(*dim_data64);
|
||||
} else {
|
||||
value_type_error = true;
|
||||
}
|
||||
} else {
|
||||
if (dim_value_ptr->isa<Int32Imm>() || dim_value_ptr->isa<Int64Imm>()) {
|
||||
dim_v = GetValue<int64_t>(dim_value_ptr);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For GatherD, 'dim' must be one of these types: [int32/int64].";
|
||||
value_type_error = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (value_type_error) {
|
||||
MS_LOG(EXCEPTION) << "For GatherD, 'dim' must be one of these types: [int32/int64].";
|
||||
}
|
||||
|
||||
return dim_v;
|
||||
}
|
||||
|
||||
|
@ -134,7 +141,9 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
(void)CheckAndConvertUtils::CheckTensorTypeValid("index", input_args[kInputIndex2]->BuildType(), index_valid_types,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("dim", input_args[kInputIndex1]->BuildType(), dim_valid_types, prim_name);
|
||||
return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
|
||||
auto infer_type = GatherDInferType(primitive, input_args);
|
||||
auto infer_shape = GatherDInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -1361,6 +1361,18 @@ def one_hot(self, depth, on_value, off_value, axis=-1):
|
|||
return P.OneHot(axis)(self, depth, on_value, off_value)
|
||||
|
||||
|
||||
def gather_elements(x, dim, index):
|
||||
r"""
|
||||
Gathers values along an axis specified by dim.
|
||||
|
||||
Refer to :func:`mindspore.ops.gather_d` for more detail.
|
||||
"""
|
||||
check_value_type('x', x, (Tensor,), 'Tensor.gather_elements')
|
||||
check_value_type('dim', dim, (int,), 'Tensor.gather_elements')
|
||||
check_value_type('index', index, (Tensor,), 'Tensor.gather_elements')
|
||||
return F.gather_elements(x, dim, index)
|
||||
|
||||
|
||||
def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disable=redefined-builtin
|
||||
"""
|
||||
Return sum of array elements over a given axis.
|
||||
|
|
|
@ -3741,6 +3741,41 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('masked_select')(self, mask)
|
||||
|
||||
def gather_elements(self, dim, index):
|
||||
"""
|
||||
Gathers values along an axis specified by dim.
|
||||
|
||||
Args:
|
||||
dim (int): The axis along which to index. It must be int32 or int64.
|
||||
index (Tensor): The indices of elements to gather. It can be one of the following data types:
|
||||
int32, int64. The value range of each index element is [-x_rank[dim], x_rank[dim]).
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape as index tensor and same data type as input tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `dim` or `index` is neither int32 nor int64.
|
||||
ValueError: If length of shape of `x` is not equal to length of shape of `index`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
|
||||
>>> index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
|
||||
>>> dim = 1
|
||||
>>> output = x.gather_elements(dim, index)
|
||||
>>> print(output)
|
||||
[[1 1]
|
||||
[4 3]]
|
||||
"""
|
||||
self._init_check()
|
||||
validator.check_value_type('index', index, (Tensor_,), 'Tensor.gather_elements')
|
||||
return tensor_operator_registry.get('gather_elements')(self, dim, index)
|
||||
|
||||
def nonzero(self):
|
||||
"""
|
||||
Return a tensor of the positions of all non-zero values.
|
||||
|
|
|
@ -73,20 +73,20 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
|
|||
# If stacked, move vmap_dim to first dim and reshape to right shape.
|
||||
if x_dim is not None:
|
||||
if x_dim != 0:
|
||||
mnp.moveaxis(x, x_dim, 0)
|
||||
x = mnp.moveaxis(x, x_dim, 0)
|
||||
x_shape = F.shape(x)
|
||||
if base_x_len == 2:
|
||||
x = F.reshape(x, _get_reshape_shape(x_shape, 1))
|
||||
|
||||
if lg_dim is not None:
|
||||
if lg_dim != 0:
|
||||
mnp.moveaxis(loss_grad, lg_dim, 0)
|
||||
loss_grad = mnp.moveaxis(loss_grad, lg_dim, 0)
|
||||
loss_grad_shape = F.shape(loss_grad)
|
||||
loss_grad = F.reshape(loss_grad, _get_reshape_shape(loss_grad_shape))
|
||||
|
||||
if target_dim is not None:
|
||||
if target_dim != 0:
|
||||
mnp.moveaxis(target, target_dim, 0)
|
||||
target = mnp.moveaxis(target, target_dim, 0)
|
||||
target_shape = F.shape(target)
|
||||
target = F.reshape(target, _get_reshape_shape(target_shape))
|
||||
|
||||
|
|
|
@ -61,6 +61,7 @@ from .array_func import (
|
|||
scatter_nd_min,
|
||||
gather,
|
||||
gather_d,
|
||||
gather_elements,
|
||||
gather_nd,
|
||||
scalar_cast,
|
||||
masked_fill,
|
||||
|
|
|
@ -1960,6 +1960,30 @@ def gather_d(x, dim, index):
|
|||
return gather_d_(x, dim, index)
|
||||
|
||||
|
||||
def gather_elements(x, dim, index):
|
||||
"""
|
||||
Gathers values along an axis specified by dim.
|
||||
|
||||
Refer to :func:`mindspore.ops.gather_d` for more detail.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
|
||||
>>> index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
|
||||
>>> dim = 1
|
||||
>>> output = mindspore.ops.gather_elements(x, dim, index)
|
||||
>>> print(output)
|
||||
[[1 1]
|
||||
[4 3]]
|
||||
"""
|
||||
return gather_d_(x, dim, index)
|
||||
|
||||
|
||||
def gather_nd(input_x, indices):
|
||||
r"""
|
||||
Gathers slices from a tensor by indices.
|
||||
|
@ -3012,6 +3036,7 @@ __all__ = [
|
|||
'tensor_scatter_min',
|
||||
'gather',
|
||||
'gather_d',
|
||||
'gather_elements',
|
||||
'gather_nd',
|
||||
'one_hot',
|
||||
'masked_fill',
|
||||
|
|
|
@ -1001,6 +1001,7 @@ tensor_operator_registry.register('reduce_sum', reduce_sum)
|
|||
tensor_operator_registry.register('tensor_slice', tensor_slice)
|
||||
tensor_operator_registry.register('select', select)
|
||||
tensor_operator_registry.register('gather_d', gather_d)
|
||||
tensor_operator_registry.register('gather_elements', gather_elements)
|
||||
tensor_operator_registry.register('gather_nd', gather_nd)
|
||||
tensor_operator_registry.register('stack', P.Stack)
|
||||
tensor_operator_registry.register('log', log)
|
||||
|
|
|
@ -6215,6 +6215,7 @@ class GatherD(Primitive):
|
|||
def __init__(self):
|
||||
"""Initialize GatherD"""
|
||||
self.init_prim_io_names(inputs=['x', 'dim', 'index'], outputs=['output'])
|
||||
self.set_const_input_indexes([1])
|
||||
|
||||
|
||||
class Identity(Primitive):
|
||||
|
|
|
@ -19,9 +19,11 @@ import mindspore
|
|||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
import mindspore.ops as ops
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops.functional import vmap
|
||||
import mindspore.numpy as ms_np
|
||||
|
||||
|
||||
context.set_context(device_target="Ascend")
|
||||
|
@ -37,6 +39,15 @@ class Net(nn.Cell):
|
|||
return self.op(x, self.dim, index)
|
||||
|
||||
|
||||
class TensorNet(nn.Cell):
|
||||
def __init__(self, dim=0):
|
||||
super(TensorNet, self).__init__()
|
||||
self.dim = dim
|
||||
|
||||
def construct(self, x, index):
|
||||
return x.gather_elements(self.dim, index)
|
||||
|
||||
|
||||
class NetGrad(nn.Cell):
|
||||
def __init__(self, dim=0, shape=None):
|
||||
super(NetGrad, self).__init__()
|
||||
|
@ -107,6 +118,64 @@ def test_gatherd_dynamic(ms_type):
|
|||
assert np.array_equal(out.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_functional():
|
||||
"""
|
||||
Feature: test GatherD function interface.
|
||||
Description: input x and index is static shape.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
|
||||
index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
|
||||
dim = 1
|
||||
output = ops.gather_elements(x, dim, index)
|
||||
expect = np.array([[1, 1], [4, 3]])
|
||||
assert np.array_equal(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_tensor():
|
||||
"""
|
||||
Feature: test GatherD tensor interface in pynative case.
|
||||
Description: input x and index is static shape.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
x = ms_np.array([[1, 2], [3, 4]])
|
||||
dim = 1
|
||||
index = ms_np.array([[0, 0], [1, 0]])
|
||||
output = x.gather_elements(dim, index)
|
||||
expect = np.array([[1, 1], [4, 3]])
|
||||
assert np.array_equal(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_graph_tensor():
|
||||
"""
|
||||
Feature: test GatherD tensor interface in graph case.
|
||||
Description: input x and index is static shape.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = ms_np.array([[1, 2], [3, 4]])
|
||||
index = ms_np.array([[0, 0], [1, 0]])
|
||||
dim = 1
|
||||
net = Net(dim)
|
||||
output = net(x, index)
|
||||
expect = np.array([[1, 1], [4, 3]])
|
||||
assert np.array_equal(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
|
Loading…
Reference in New Issue