!47047 Modify ops.top_k to pytorch

Merge pull request !47047 from 冯一航/modify_top_k_to_pytorch
This commit is contained in:
i-robot 2023-01-16 09:03:31 +00:00 committed by Gitee
commit c52cdf7095
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 163 additions and 36 deletions

View File

@ -1,31 +1,32 @@
mindspore.ops.top_k
===================
.. py:function:: mindspore.ops.top_k(input_x, k, sorted=True)
.. py:function:: mindspore.ops.top_k(input_x, k, dim=None, sorted=True)
沿最后一个维度查找 `k` 个最大元素和对应的索引。
沿给定维度查找 `k` 个最大元素和对应的索引。
.. warning::
- 如果 `sorted` 设置为 'False'，它将使用aicpu运算符，性能可能会降低。
如果 `input_x` 是一维Tensor，则查找Tensor中 `k` 个最大元素，并将其值和索引输出为Tensor。因此 `values[k]` 是 `input_x` 中 `k` 个最大元素，其索引是 `indices[k]` 。
对于多维矩阵,计算每行中最大的 `k` 个元素(沿最后一个维度的相应向量),因此:
对于多维矩阵,计算给定维度中最大的 `k` 个元素,因此:
.. math::
values.shape = indices.shape = input.shape[:-1] + [k].
values.shape = indices.shape
如果两个比较的元素相同,则优先返回索引值较小的元素。
参数:
- **input_x** (Tensor) - 需计算的输入，数据类型必须为float16、float32或int32。
- **k** (int) - 指定计算最大元素的数量,需要是常量。
- **dim** (int, 可选) - 需要排序的维度。默认值：None，表示沿最后一个维度计算。
- **sorted** (bool, 可选) - 如果为True，则获取的元素将按值降序排序。默认值：True。
返回:
2个Tensor组成的tuple，分别为 `values` 和 `indices` 。
- **values** (Tensor) - 最后一个维度的每个切片中的 `k` 最大元素。
- **values** (Tensor) - 给定维度的每个切片中的 `k` 个最大元素。
- **indices** (Tensor) - `k` 个最大元素的对应索引。
异常:

View File

@ -3366,11 +3366,11 @@ def ceil(x):
return F.ceil(x)
def top_k(input_x, k, sorted=True):
def top_k(input_x, k, dim=None, sorted=True):
r"""
For details, please refer to :func:`mindspore.ops.top_k`.
"""
Finds values and indices of the `k` largest entries along a given dimension.
"""
return F.top_k(input_x, k, sorted)
return F.top_k(input_x, k, dim, sorted)
def subtract(x, other, *, alpha=1):

View File

@ -3465,14 +3465,12 @@ class Tensor(Tensor_):
"""
return tensor_operator_registry.get('tile')()(self, multiples)
def top_k(self, k, sorted=True):
def top_k(self, k, dim=None, sorted=True):
r"""
For details, please refer to :func:`mindspore.ops.top_k`.
"""
self._init_check()
validator.check_is_int(k, 'k')
validator.check_bool(sorted, 'sorted')
return tensor_operator_registry.get("top_k")(sorted)(self, k)
return tensor_operator_registry.get("top_k")(self, k, dim, sorted)
def sigmoid(self):
r"""

View File

@ -5612,9 +5612,9 @@ def unsorted_segment_sum(input_x, segment_ids, num_segments):
return unsorted_segment_sum_(input_x, segment_ids, num_segments)
def top_k(input_x, k, sorted=True):
def top_k(input_x, k, dim=None, sorted=True):
r"""
Finds values and indices of the `k` largest entries along the last dimension.
Finds values and indices of the `k` largest entries along a given dimension.
.. warning::
- If sorted is set to 'False', it will use the aicpu operator, the performance may be reduced.
@ -5624,24 +5624,25 @@ def top_k(input_x, k, sorted=True):
and its index is indices [`k`].
For a multi-dimensional matrix,
calculates the first `k` entries in each row (corresponding vector along the last dimension), therefore:
calculates the `k` largest entries along the given dimension, therefore:
.. math::
values.shape = indices.shape = input.shape[:-1] + [k].
values.shape = indices.shape
If the two compared elements are the same, the one with the smaller index value is returned first.
Args:
input_x (Tensor): Input to be computed, data type must be float16, float32 or int32.
k (int): The number of top elements to be computed along the given dimension, constant input is needed.
dim (int, optional): The dimension to sort along. Default: None, which means the last dimension is used.
sorted (bool, optional): If true, the obtained elements will be sorted by the values in descending order.
Default: True.
Returns:
Tuple of 2 tensors, the values and the indices.
- values (Tensor): The `k` largest elements in each slice of the last dimension.
- values (Tensor): The `k` largest elements in each slice of the given dimension.
- indices (Tensor): The indices of values within the given dimension of input.
Raises:
@ -5654,18 +5655,30 @@ def top_k(input_x, k, sorted=True):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore import Tensor
>>> import mindspore as ms
>>> from mindspore import ops
>>> import mindspore
>>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16)
>>> k = 3
>>> values, indices = ops.top_k(input_x, k, sorted=True)
>>> print((values, indices))
(Tensor(shape=[3], dtype=Float16, value= [ 5.0000e+00, 4.0000e+00, 3.0000e+00]), Tensor(shape=[3],
dtype=Int32, value= [4, 3, 2]))
>>> x = ms.Tensor([[0.5368, 0.2447, 0.4302, 0.9673],
... [0.4388, 0.6525, 0.4685, 0.1868],
... [0.3563, 0.5152, 0.9675, 0.8230]], dtype=ms.float32)
>>> output = ops.top_k(x, 2, dim=1)
>>> print(output)
(Tensor(shape=[3, 2], dtype=Float32, value=
[[ 9.67299998e-01, 5.36800027e-01],
[ 6.52499974e-01, 4.68499988e-01],
[ 9.67499971e-01, 8.23000014e-01]]), Tensor(shape=[3, 2], dtype=Int32, value=
[[3, 0],
[1, 2],
[2, 3]]))
"""
top_k_ = _get_cache_prim(P.TopK)(sorted)
return top_k_(input_x, k)
if dim is None or dim == input_x.ndim - 1:
return top_k_(input_x, k)
input_x = input_x.swapaxes(dim, input_x.ndim - 1)
output = top_k_(input_x, k)
values = output[0].swapaxes(dim, input_x.ndim - 1)
indices = output[1].swapaxes(dim, input_x.ndim - 1)
res = (values, indices)
return res
def expand(input_x, size):

View File

@ -352,7 +352,7 @@ tensor_operator_registry.register('coalesce', coalesce)
tensor_operator_registry.register('argmax_with_value', max)
tensor_operator_registry.register('argmin_with_value', min)
tensor_operator_registry.register('coo_add', coo_add)
tensor_operator_registry.register('top_k', P.TopK)
tensor_operator_registry.register('top_k', top_k)
tensor_operator_registry.register('isfinite', P.IsFinite)
tensor_operator_registry.register('to', P.Cast)
tensor_operator_registry.register('bool', P.Cast)

View File

@ -95,24 +95,24 @@ def test_top_k_functional():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
x_np = np.random.rand(3, 4).astype(np.float32)
k = 4
ms_output = F.top_k(Tensor(x_np), k, True)
ms_output = F.top_k(Tensor(x_np), k, sorted=True)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)
x_np = np.random.rand(3, 4).astype(np.float32)
k = 4
ms_output = F.top_k(Tensor(x_np), k, False)
ms_output = F.top_k(Tensor(x_np), k, sorted=False)
assert np.allclose(ms_output[0].asnumpy(), x_np)
x_np = np.random.rand(2, 3, 4).astype(np.float32)
k = 2
ms_output = F.top_k(Tensor(x_np), k, True)
ms_output = F.top_k(Tensor(x_np), k, sorted=True)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)
x_np = np.random.rand(512, 1024).astype(np.float32)
k = 512
ms_output = F.top_k(Tensor(x_np), k, True)
ms_output = F.top_k(Tensor(x_np), k, sorted=True)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)
@ -129,23 +129,23 @@ def test_top_k_tensor():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
x_np = np.random.rand(3, 4).astype(np.float32)
k = 4
ms_output = Tensor(x_np).top_k(k, True)
ms_output = Tensor(x_np).top_k(k, sorted=True)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)
x_np = np.random.rand(3, 4).astype(np.float32)
k = 4
ms_output = Tensor(x_np).top_k(k, False)
ms_output = Tensor(x_np).top_k(k, sorted=False)
assert np.allclose(ms_output[0].asnumpy(), x_np)
x_np = np.random.rand(2, 3, 4).astype(np.float32)
k = 2
ms_output = Tensor(x_np).top_k(k, True)
ms_output = Tensor(x_np).top_k(k, sorted=True)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)
x_np = np.random.rand(512, 1024).astype(np.float32)
k = 512
ms_output = Tensor(x_np).top_k(k, True)
ms_output = Tensor(x_np).top_k(k, sorted=True)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)

View File

@ -0,0 +1,58 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
    """Minimal Cell that forwards its inputs to the functional ``ops.top_k``."""

    # pylint: disable=redefined-builtin
    def construct(self, input_x, k, dim=None, sorted=True):
        """Call ``ops.top_k`` with the given arguments and return its result unchanged."""
        return ops.top_k(input_x, k, dim=dim, sorted=sorted)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_top_k_normal(mode):
    """
    Feature: top_k
    Description: Verify the result of top_k along the last dimension and along a non-last dimension
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net()
    x = ms.Tensor([[0.5368, 0.2447, 0.4302, 0.9673],
                   [0.4388, 0.6525, 0.4685, 0.1868],
                   [0.3563, 0.5152, 0.9675, 0.8230]], dtype=ms.float32)

    # Case 1: dim=1 is the last axis of this 2-D input.
    output = net(x, 2, dim=1)
    expect_values = np.array([[0.9673, 0.5368],
                              [0.6525, 0.4685],
                              [0.9675, 0.823]])
    expect_indices = np.array([[3, 0],
                               [1, 2],
                               [2, 3]])
    assert np.allclose(output[0].asnumpy(), expect_values)
    # Indices are integers, so compare them exactly rather than with a tolerance.
    assert np.array_equal(output[1].asnumpy(), expect_indices)

    # Case 2: dim=0 is NOT the last axis, so the top-k is taken per column.
    # Without this case the non-last-dimension path is never exercised.
    output = net(x, 2, dim=0)
    expect_values = np.array([[0.5368, 0.6525, 0.9675, 0.9673],
                              [0.4388, 0.5152, 0.4685, 0.823]])
    expect_indices = np.array([[0, 1, 2, 0],
                               [1, 2, 1, 2]])
    assert np.allclose(output[0].asnumpy(), expect_values)
    assert np.array_equal(output[1].asnumpy(), expect_indices)

View File

@ -0,0 +1,57 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
    """Minimal Cell that forwards its inputs to the ``Tensor.top_k`` method."""

    # pylint: disable=redefined-builtin
    def construct(self, input_x, k, dim=None, sorted=True):
        """Call ``input_x.top_k`` with the given arguments and return its result unchanged."""
        return input_x.top_k(k, dim=dim, sorted=sorted)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_top_k_normal(mode):
    """
    Feature: top_k
    Description: Verify the result of Tensor.top_k along the last dimension and along a non-last dimension
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net()
    x = ms.Tensor([[0.5368, 0.2447, 0.4302, 0.9673],
                   [0.4388, 0.6525, 0.4685, 0.1868],
                   [0.3563, 0.5152, 0.9675, 0.8230]], dtype=ms.float32)

    # Case 1: dim=1 is the last axis of this 2-D input.
    output = net(x, 2, dim=1)
    expect_values = np.array([[0.9673, 0.5368],
                              [0.6525, 0.4685],
                              [0.9675, 0.823]])
    expect_indices = np.array([[3, 0],
                               [1, 2],
                               [2, 3]])
    assert np.allclose(output[0].asnumpy(), expect_values)
    # Indices are integers, so compare them exactly rather than with a tolerance.
    assert np.array_equal(output[1].asnumpy(), expect_indices)

    # Case 2: dim=0 is NOT the last axis, so the top-k is taken per column.
    # Without this case the non-last-dimension path is never exercised.
    output = net(x, 2, dim=0)
    expect_values = np.array([[0.5368, 0.6525, 0.9675, 0.9673],
                              [0.4388, 0.5152, 0.4685, 0.823]])
    expect_indices = np.array([[0, 1, 2, 0],
                               [1, 2, 1, 2]])
    assert np.allclose(output[0].asnumpy(), expect_values)
    assert np.array_equal(output[1].asnumpy(), expect_indices)