diff --git a/docs/api/api_python/mindspore/mindspore.Tensor.rst b/docs/api/api_python/mindspore/mindspore.Tensor.rst index 726dd7910b5..f942eb3dd51 100644 --- a/docs/api/api_python/mindspore/mindspore.Tensor.rst +++ b/docs/api/api_python/mindspore/mindspore.Tensor.rst @@ -552,7 +552,28 @@ mindspore.Tensor **异常:** - **ValueError** - 如果输入Tensor的shape长度小于 `indices` 的最后一个维度。 + + .. py:method:: gather(input_indices, axis) + 返回指定 `axis` 上 `input_indices` 的元素对应的输入Tensor切片。为了方便描述,对于输入Tensor记为`input_params`。 + .. note:: + 1. input_indices 的值必须在`[0, input_params.shape[axis])`的范围内,结果未定义超出范围。 + 2. 当前在Ascend平台,input_params的值不能是 `bool_ `类型。 + + **参数:** + + - **input_indices**(Tensor) - 待切片的索引张量,其形状为 :math:`(y_1, y_2, ..., y_S)`,代表指定原始张量元素的索引,其数据类型包括:int32,int64。 + - **axis**(int) - 指定维度索引的轴以搜集切片。 + + **返回:** + + Tensor,其中shape维度为 :math:`input\_params.shape[:axis] + input\_indices.shape + input\_params.shape[axis + 1:]`。 + + **异常:** + + - **TypeError** - 如果`axis`不是一个整数。 + - **TypeError** - 如果`input_indices`不是一个整数类型的Tensor。 + .. py:method:: ger(x) 计算两个Tensor的外积,即计算此Tensor 和 `x` 的外积。如果此Tensor shape为 :math:`(m,)` ,`x` shape为 :math:`(n,)` , diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index 8abfdce79c3..736ecd9c32a 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -233,6 +233,7 @@ BuiltInTypeMap &GetMethodMap() { {"padding", std::string("padding")}, // padding() {"searchsorted", std::string("searchsorted")}, // P.Select() {"take", std::string("take")}, // P.GatherNd() + {"gather", std::string("gather")}, // P.Gather() {"scatter_add", std::string("tensor_scatter_add")}, // P.TensorScatterAdd() {"scatter_mul", std::string("tensor_scatter_mul")}, // tensor_scatter_mul() {"scatter_sub", std::string("tensor_scatter_sub")}, // P.TensorScatterSub() diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py index eef0137eb38..fb7c582d159 100644 --- a/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -2324,6 +2324,14 @@ def gather_nd(input_x, indices): return F.gather_nd(input_x, indices) +def gather(input_x, input_indices, axis): + r""" + Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`. + Refer to :func:`mindspore.ops.gather` for more detail. + """ + return F.gather(input_x, input_indices, axis) + + def pdist(x, p=2.0): r""" Computes the p-norm distance between each pair of row vectors in the input. diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index 84dd9dde4bc..0d51851a927 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -1640,8 +1640,6 @@ class Tensor(Tensor_): self._init_check() return tensor_operator_registry.get('col2im')(self, output_size, kernel_size, dilation, padding_value, stride) - - def reshape(self, *shape): """ Give a new shape to a tensor without changing its data. @@ -3366,6 +3364,81 @@ class Tensor(Tensor_): validator.check_value_type('indices', indices, (Tensor_,), 'Tensor.gather_nd') return tensor_operator_registry.get('gather_nd')(self, indices) + def gather(self, input_indices, axis): + r""" + Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`. + The shape of input tensor is :math:`(x_1, x_2, ..., x_R)`. For convenience, define it as `input_params` + + The following figure shows the calculation process of Gather commonly: + + .. image:: Gather.png + + where params represents the input `input_params`, and indices represents the index to be sliced `input_indices`. + + .. note:: + 1. The value of input_indices must be in the range of `[0, input_param.shape[axis])`, the result + is undefined out of range. + + 2. The data type of input_params cannot be + `bool_ `_ on Ascend + platform currently. + + Args: + input_indices (Tensor): Index tensor to be sliced, the shape of tensor is :math:`(y_1, y_2, ..., y_S)`. + Specifies the indices of elements of the original Tensor. The data type can be int32 or int64. + axis (int): Specifies the dimension index to gather indices. + + Returns: + Tensor, the shape of tensor is + :math:`input\_params.shape[:axis] + input\_indices.shape + input\_params.shape[axis + 1:]`. + + Raises: + TypeError: If `axis` is not an int. + TypeError: If `input_indices` is not a tensor of type int. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> # case1: input_indices is a Tensor with shape (5, ). + >>> input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32) + >>> input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32) + >>> axis = 0 + >>> output = input_params.gather(input_indices, axis) + >>> print(output) + [1. 3. 5. 3. 7.] + >>> # case2: input_indices is a Tensor with shape (2, 2). When the input_params has one dimension, + >>> # the output shape is equal to the input_indices shape. + >>> input_indices = Tensor(np.array([[0, 2], [2, 6]]), mindspore.int32) + >>> axis = 0 + >>> output = input_params.gather(input_indices, axis) + >>> print(output) + [[ 1. 3.] + [ 3. 7.]] + >>> # case3: input_indices is a Tensor with shape (2, ) and + >>> # input_params is a Tensor with shape (3, 4) and axis is 0. + >>> input_params = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), mindspore.float32) + >>> input_indices = Tensor(np.array([0, 2]), mindspore.int32) + >>> axis = 0 + >>> output = input_params.gather(input_indices, axis) + >>> print(output) + [[1. 2. 3. 4.] + [9. 10. 11. 12.]] + >>> # case4: input_indices is a Tensor with shape (2, ) and + >>> # input_params is a Tensor with shape (3, 4) and axis is 1. + >>> input_params = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), mindspore.float32) + >>> input_indices = Tensor(np.array([0, 2]), mindspore.int32) + >>> axis = 1 + >>> output = input_params.gather(input_indices, axis) + >>> print(output) + [[1. 3.] + [5. 7.] + [9. 11.]] + """ + self._init_check() + validator.check_is_int(axis, 'axis') + return tensor_operator_registry.get('gather')(self, input_indices, axis) + def var(self, axis=None, ddof=0, keepdims=False): """ Compute the variance along the specified axis. diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py index 3b1f7f2b65a..82e2bcacbe1 100644 --- a/mindspore/python/mindspore/ops/functional.py +++ b/mindspore/python/mindspore/ops/functional.py @@ -998,6 +998,7 @@ tensor_operator_registry.register('eye', eye) tensor_operator_registry.register('reduce_sum', reduce_sum) tensor_operator_registry.register('tensor_slice', tensor_slice) tensor_operator_registry.register('select', select) +tensor_operator_registry.register('gather', gather) tensor_operator_registry.register('gather_d', gather_d) tensor_operator_registry.register('gather_elements', gather_elements) tensor_operator_registry.register('gather_nd', gather_nd) diff --git a/tests/st/ops/ascend/test_tbe_ops/test_gather.py b/tests/st/ops/ascend/test_tbe_ops/test_gather.py new file mode 100644 index 00000000000..7b73c9874f5 --- /dev/null +++ b/tests/st/ops/ascend/test_tbe_ops/test_gather.py @@ -0,0 +1,144 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.ops import operations as P + + +class GatherTensorTestNet(nn.Cell): + def construct(self, x, indices, axis): + return x.gather(indices, axis) + + +class GatherNet(nn.Cell): + def __init__(self): + super(GatherNet, self).__init__() + self.gather = P.Gather() + + def construct(self, input_x, indices, axis): + return self.gather(input_x, indices, axis) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_tensor_graph_mode(): + """ + Feature: gather tensor test on graph mode. + Description: test gather tensor's interface on graph mode. + Expectation: the result equal to expect. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32) + input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32) + axis = 0 + net = GatherTensorTestNet() + output = net(input_params, input_indices, axis) + expect_np = np.array([1., 3., 5., 3., 7.]) + rtol = 1.e-4 + atol = 1.e-4 + assert np.allclose(output.asnumpy(), expect_np, rtol, atol, equal_nan=True) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_tensor_pynative_mode(): + """ + Feature: gather tensor test on pynative mode. + Description: test gather tensor's interface on pynative mode. + Expectation: the result equal to expect. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32) + input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32) + axis = 0 + net = GatherTensorTestNet() + output = net(input_params, input_indices, axis) + expect_np = np.array([1., 3., 5., 3., 7.]) + rtol = 1.e-4 + atol = 1.e-4 + assert np.allclose(output.asnumpy(), expect_np, rtol, atol, equal_nan=True) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_functional_pynative_mode(): + """ + Feature: gather functional test on pynative mode. + Description: test gather_nd functional's interface on pynative mode. + Expectation: the result equal to expect. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32) + input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32) + axis = 0 + output = ops.gather(input_params, input_indices, axis) + expect_np = np.array([1., 3., 5., 3., 7.]) + rtol = 1.e-4 + atol = 1.e-4 + assert np.allclose(output.asnumpy(), expect_np, rtol, atol, equal_nan=True) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_functional_graph_mode(): + """ + Feature: gather functional test on graph mode. + Description: test gather functional's interface on graph mode. + Expectation: the result equal to expect. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32) + input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32) + axis = 0 + output = ops.gather(input_params, input_indices, axis) + expect_np = np.array([1., 3., 5., 3., 7.]) + rtol = 1.e-4 + atol = 1.e-4 + assert np.allclose(output.asnumpy(), expect_np, rtol, atol, equal_nan=True) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_gather_static_pynative_mode(): + """ + Feature: gather static shape test on pynative mode. + Description: test static shape for gather on pynative mode. + Expectation: the result equal to expect. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32) + input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32) + axis = 0 + net = GatherNet() + output = net(input_params, input_indices, axis) + expect_np = np.array([1., 3., 5., 3., 7.]) + rtol = 1.e-4 + atol = 1.e-4 + assert np.allclose(output.asnumpy(), expect_np, rtol, atol, equal_nan=True)