support gather_v2 tensor and functional interface
This commit is contained in:
parent
8c9b466693
commit
33d0690eb6
@@ -552,7 +552,28 @@ mindspore.Tensor

    **Raises:**

    - **ValueError** - If the length of the shape of the input Tensor is less than the last dimension of `indices`.

    .. py:method:: gather(input_indices, axis)

        Returns the slices of the input Tensor corresponding to the elements of `input_indices` on the specified `axis`. For convenience, the input Tensor is referred to as `input_params` below.

        .. note::
            1. The values of `input_indices` must be in the range `[0, input_params.shape[axis])`; the result is undefined for out-of-range values.
            2. On the Ascend platform, the data type of `input_params` currently cannot be `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.

        **Parameters:**

        - **input_indices** (Tensor) - The index Tensor to be sliced, with shape :math:`(y_1, y_2, ..., y_S)`. It specifies the indices of the elements of the original Tensor; its data type can be int32 or int64.
        - **axis** (int) - The dimension index along which to gather the slices.

        **Returns:**

        Tensor, with shape :math:`input\_params.shape[:axis] + input\_indices.shape + input\_params.shape[axis + 1:]`.

        **Raises:**

        - **TypeError** - If `axis` is not an int.
        - **TypeError** - If `input_indices` is not a Tensor of an int type.

    .. py:method:: ger(x)

        Computes the outer product of two Tensors, that is, the outer product of this Tensor and `x`. If the shape of this Tensor is :math:`(m,)` and the shape of `x` is :math:`(n,)`,
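As a reading aid for the shape rule documented above, here is a minimal sketch (not part of the commit; it assumes a working MindSpore install and uses only `mindspore.ops.gather`, `Tensor`, and NumPy's `np.take`) that checks the documented output shape and values against NumPy:

import numpy as np
from mindspore import Tensor, ops

# A 3-D input and 2-D indices, gathered along axis=1.
params_np = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)
indices_np = np.array([[0, 2], [1, 3]], dtype=np.int32)
axis = 1

output = ops.gather(Tensor(params_np), Tensor(indices_np), axis)

# Documented rule: params.shape[:axis] + indices.shape + params.shape[axis + 1:]
expected_shape = params_np.shape[:axis] + indices_np.shape + params_np.shape[axis + 1:]
assert output.shape == expected_shape  # (2, 2, 2, 3)

# np.take has the same single-axis indexing semantics, so it serves as a reference.
assert np.allclose(output.asnumpy(), np.take(params_np, indices_np, axis=axis))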
@@ -233,6 +233,7 @@ BuiltInTypeMap &GetMethodMap() {
     {"padding", std::string("padding")},                   // padding()
     {"searchsorted", std::string("searchsorted")},         // P.Select()
     {"take", std::string("take")},                         // P.GatherNd()
     {"gather", std::string("gather")},                     // P.Gather()
     {"scatter_add", std::string("tensor_scatter_add")},    // P.TensorScatterAdd()
     {"scatter_mul", std::string("tensor_scatter_mul")},    // tensor_scatter_mul()
     {"scatter_sub", std::string("tensor_scatter_sub")},    // P.TensorScatterSub()
@@ -2324,6 +2324,14 @@ def gather_nd(input_x, indices):
    return F.gather_nd(input_x, indices)


def gather(input_x, input_indices, axis):
    r"""
    Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.

    Refer to :func:`mindspore.ops.gather` for more details.
    """
    return F.gather(input_x, input_indices, axis)


def pdist(x, p=2.0):
    r"""
    Computes the p-norm distance between each pair of row vectors in the input.
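Since the commit wires the same underlying operator into both the functional and the Tensor-method entry points, a short sketch (not part of the diff; it assumes a build containing this commit, where both `ops.gather` and `Tensor.gather` are available) showing that the two forms agree:

import numpy as np
import mindspore
from mindspore import Tensor, ops

input_params = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.float32)
input_indices = Tensor(np.array([3, 0]), mindspore.int32)
axis = 1

# Functional interface added by this commit (delegates to F.gather).
out_fn = ops.gather(input_params, input_indices, axis)
# Tensor method added by this commit (dispatched through tensor_operator_registry).
out_method = input_params.gather(input_indices, axis)

assert np.allclose(out_fn.asnumpy(), out_method.asnumpy())
print(out_fn)
# [[4. 1.]
#  [8. 5.]]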
@@ -1640,8 +1640,6 @@ class Tensor(Tensor_):
        self._init_check()
        return tensor_operator_registry.get('col2im')(self, output_size, kernel_size, dilation, padding_value, stride)

    def reshape(self, *shape):
        """
        Give a new shape to a tensor without changing its data.
@@ -3366,6 +3364,81 @@ class Tensor(Tensor_):
        validator.check_value_type('indices', indices, (Tensor_,), 'Tensor.gather_nd')
        return tensor_operator_registry.get('gather_nd')(self, indices)

    def gather(self, input_indices, axis):
        r"""
        Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.
        The shape of the input tensor is :math:`(x_1, x_2, ..., x_R)`. For convenience, define it as `input_params`.

        The following figure shows the calculation process of Gather commonly:

        .. image:: Gather.png

        where params represents the input `input_params`, and indices represents the index to be sliced `input_indices`.

        .. note::
            1. The value of input_indices must be in the range of `[0, input_params.shape[axis])`; the result
               is undefined if it is out of range.

            2. The data type of input_params cannot be
               `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ on the Ascend
               platform currently.

        Args:
            input_indices (Tensor): Index tensor to be sliced, the shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
                Specifies the indices of elements of the original Tensor. The data type can be int32 or int64.
            axis (int): Specifies the dimension index along which to gather the slices.

        Returns:
            Tensor, the shape of tensor is
            :math:`input\_params.shape[:axis] + input\_indices.shape + input\_params.shape[axis + 1:]`.

        Raises:
            TypeError: If `axis` is not an int.
            TypeError: If `input_indices` is not a tensor of type int.

        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``

        Examples:
            >>> # case1: input_indices is a Tensor with shape (5, ).
            >>> input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32)
            >>> input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32)
            >>> axis = 0
            >>> output = input_params.gather(input_indices, axis)
            >>> print(output)
            [1. 3. 5. 3. 7.]
            >>> # case2: input_indices is a Tensor with shape (2, 2). When the input_params has one dimension,
            >>> # the output shape is equal to the input_indices shape.
            >>> input_indices = Tensor(np.array([[0, 2], [2, 6]]), mindspore.int32)
            >>> axis = 0
            >>> output = input_params.gather(input_indices, axis)
            >>> print(output)
            [[1. 3.]
             [3. 7.]]
            >>> # case3: input_indices is a Tensor with shape (2, ) and
            >>> # input_params is a Tensor with shape (3, 4) and axis is 0.
            >>> input_params = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), mindspore.float32)
            >>> input_indices = Tensor(np.array([0, 2]), mindspore.int32)
            >>> axis = 0
            >>> output = input_params.gather(input_indices, axis)
            >>> print(output)
            [[1. 2. 3. 4.]
             [9. 10. 11. 12.]]
            >>> # case4: input_indices is a Tensor with shape (2, ) and
            >>> # input_params is a Tensor with shape (3, 4) and axis is 1.
            >>> input_params = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), mindspore.float32)
            >>> input_indices = Tensor(np.array([0, 2]), mindspore.int32)
            >>> axis = 1
            >>> output = input_params.gather(input_indices, axis)
            >>> print(output)
            [[1. 3.]
             [5. 7.]
             [9. 11.]]
        """
        self._init_check()
        validator.check_is_int(axis, 'axis')
        return tensor_operator_registry.get('gather')(self, input_indices, axis)

    def var(self, axis=None, ddof=0, keepdims=False):
        """
        Compute the variance along the specified axis.
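The note in the docstring above warns that out-of-range indices give undefined results. A small, hedged sketch of one way a caller might guard against that (clipping on the NumPy side before building the index Tensor; `np.clip` is the only helper assumed, and this guard is not part of the commit):

import numpy as np
from mindspore import Tensor

params = Tensor(np.array([10., 20., 30.], dtype=np.float32))
raw_indices = np.array([0, 2, 5, -1])  # 5 and -1 fall outside [0, 3)

# Clamp into the valid range [0, params.shape[axis]) before gathering,
# since gather's behavior for out-of-range indices is undefined.
axis = 0
safe_indices = np.clip(raw_indices, 0, params.shape[axis] - 1).astype(np.int32)

output = params.gather(Tensor(safe_indices), axis)
print(output)  # [10. 30. 30. 10.]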
@@ -998,6 +998,7 @@ tensor_operator_registry.register('eye', eye)
tensor_operator_registry.register('reduce_sum', reduce_sum)
tensor_operator_registry.register('tensor_slice', tensor_slice)
tensor_operator_registry.register('select', select)
tensor_operator_registry.register('gather', gather)
tensor_operator_registry.register('gather_d', gather_d)
tensor_operator_registry.register('gather_elements', gather_elements)
tensor_operator_registry.register('gather_nd', gather_nd)
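The registration above is what lets `Tensor.gather` find its implementation at call time via `tensor_operator_registry.get('gather')`. As a rough illustration only (a generic sketch of the name-to-callable registry pattern; the class and names below are invented for this example and are not MindSpore's actual registry code), the mechanism looks roughly like this:

# Generic sketch of a name -> callable registry; all names here are illustrative.
class OperatorRegistry:
    def __init__(self):
        self._ops = {}

    def register(self, name, func):
        self._ops[name] = func

    def get(self, name):
        return self._ops[name]


registry = OperatorRegistry()
registry.register('gather', lambda params, indices, axis: [params[i] for i in indices])

# A method such as Tensor.gather can then dispatch by name:
print(registry.get('gather')([10, 20, 30], [0, 2, 2], 0))  # [10, 30, 30]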
@@ -0,0 +1,144 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.ops import operations as P


class GatherTensorTestNet(nn.Cell):
    # Exercises the Tensor method interface added by this commit.
    def construct(self, x, indices, axis):
        return x.gather(indices, axis)


class GatherNet(nn.Cell):
    # Exercises the P.Gather primitive directly.
    def __init__(self):
        super(GatherNet, self).__init__()
        self.gather = P.Gather()

    def construct(self, input_x, indices, axis):
        return self.gather(input_x, indices, axis)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tensor_graph_mode():
    """
    Feature: gather tensor test on graph mode.
    Description: test the Tensor.gather interface in graph mode.
    Expectation: the result equals the expected value.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32)
    input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32)
    axis = 0
    net = GatherTensorTestNet()
    output = net(input_params, input_indices, axis)
    expect_np = np.array([1., 3., 5., 3., 7.])
    rtol = 1.e-4
    atol = 1.e-4
    assert np.allclose(output.asnumpy(), expect_np, rtol, atol, equal_nan=True)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tensor_pynative_mode():
    """
    Feature: gather tensor test on pynative mode.
    Description: test the Tensor.gather interface in PyNative mode.
    Expectation: the result equals the expected value.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32)
    input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32)
    axis = 0
    net = GatherTensorTestNet()
    output = net(input_params, input_indices, axis)
    expect_np = np.array([1., 3., 5., 3., 7.])
    rtol = 1.e-4
    atol = 1.e-4
    assert np.allclose(output.asnumpy(), expect_np, rtol, atol, equal_nan=True)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_functional_pynative_mode():
    """
    Feature: gather functional test on pynative mode.
    Description: test the ops.gather functional interface in PyNative mode.
    Expectation: the result equals the expected value.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32)
    input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32)
    axis = 0
    output = ops.gather(input_params, input_indices, axis)
    expect_np = np.array([1., 3., 5., 3., 7.])
    rtol = 1.e-4
    atol = 1.e-4
    assert np.allclose(output.asnumpy(), expect_np, rtol, atol, equal_nan=True)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_functional_graph_mode():
    """
    Feature: gather functional test on graph mode.
    Description: test the ops.gather functional interface in graph mode.
    Expectation: the result equals the expected value.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32)
    input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32)
    axis = 0
    output = ops.gather(input_params, input_indices, axis)
    expect_np = np.array([1., 3., 5., 3., 7.])
    rtol = 1.e-4
    atol = 1.e-4
    assert np.allclose(output.asnumpy(), expect_np, rtol, atol, equal_nan=True)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gather_static_pynative_mode():
    """
    Feature: gather static shape test on pynative mode.
    Description: test static shape for the P.Gather primitive in PyNative mode.
    Expectation: the result equals the expected value.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32)
    input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32)
    axis = 0
    net = GatherNet()
    output = net(input_params, input_indices, axis)
    expect_np = np.array([1., 3., 5., 3., 7.])
    rtol = 1.e-4
    atol = 1.e-4
    assert np.allclose(output.asnumpy(), expect_np, rtol, atol, equal_nan=True)
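Beyond the happy-path cases above, one might also cover the TypeError documented in the Raises section for non-integer indices. A hedged sketch of such an extra check (not part of this commit; it reuses the imports and markers of the test file above and assumes, per the docstring, that float indices are rejected with a TypeError in PyNative mode):

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tensor_invalid_indices_dtype():
    """
    Feature: gather tensor test with an invalid index dtype.
    Description: float32 indices should be rejected, per the documented Raises section.
    Expectation: a TypeError is raised.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32)
    bad_indices = Tensor(np.array([0., 2., 4.]), mindspore.float32)
    with pytest.raises(TypeError):
        input_params.gather(bad_indices, 0)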