!47926 modify topk to pytorch(add largest)

Merge pull request !47926 from 冯一航/modify_top_k_topk
This commit is contained in:
i-robot 2023-01-17 07:39:16 +00:00 committed by Gitee
commit fb36b4b482
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 87 additions and 41 deletions

View File

@ -1,16 +1,16 @@
mindspore.ops.top_k mindspore.ops.topk
=================== ===================
.. py:function:: mindspore.ops.top_k(input_x, k, dim=None, sorted=True) .. py:function:: mindspore.ops.topk(input_x, k, dim=None, largest=True, sorted=True):
沿给定维度查找 `k` 个最大元素和对应的索引。 沿给定维度查找 `k` 个最大或最小元素和对应的索引。
.. warning:: .. warning::
- 如果 `sorted` 设置为'False'它将使用aicpu运算符性能可能会降低。 - 如果 `sorted` 设置为'False'它将使用aicpu运算符性能可能会降低。
如果 `input_x` 是一维Tensor则查找Tensor中 `k` 个最大元素并将其值和索引输出为Tensor。因此 `values[k]``input_x``k` 个最大元素,其索引是 `indices[k]` 如果 `input_x` 是一维Tensor则查找Tensor中 `k` 个最大或最小元素并将其值和索引输出为Tensor。因此 `values[k]``input_x``k` 个最大元素,其索引是 `indices[k]`
对于多维矩阵,计算给定维度中最大的 `k` 个元素,因此: 对于多维矩阵,计算给定维度中最大或最小`k` 个元素,因此:
.. math:: .. math::
values.shape = indices.shape values.shape = indices.shape
@ -19,14 +19,15 @@ mindspore.ops.top_k
参数: 参数:
- **input_x** (Tensor) - 需计算的输入数据类型必须为float16、float32或int32。 - **input_x** (Tensor) - 需计算的输入数据类型必须为float16、float32或int32。
- **k** (int) - 指定计算最大元素的数量,需要是常量。 - **k** (int) - 指定计算最大或最小元素的数量,需要是常量。
- **dim** (int, 可选) - 需要排序的维度。默认值None。 - **dim** (int, 可选) - 需要排序的维度。默认值None。
- **largest** (bool, 可选) - 如果为False则会返回前k个最小值。默认值True。
- **sorted** (bool, 可选) - 如果为True则获取的元素将按值降序排序。默认值True。 - **sorted** (bool, 可选) - 如果为True则获取的元素将按值降序排序。默认值True。
返回: 返回:
2个Tensor组成的tuple `values``indices` 2个Tensor组成的tuple `values``indices`
- **values** (Tensor) - 给定维度的每个切片中的 `k` 最大元素。 - **values** (Tensor) - 给定维度的每个切片中的 `k` 最大元素或最小元素
- **indices** (Tensor) - `k` 最大元素的对应索引。 - **indices** (Tensor) - `k` 最大元素的对应索引。
异常: 异常:

View File

@ -371,7 +371,7 @@ BuiltInTypeMap &GetMethodMap() {
{"argmax_with_value", std::string("argmax_with_value")}, // P.ArgMaxWithValue {"argmax_with_value", std::string("argmax_with_value")}, // P.ArgMaxWithValue
{"argmin_with_value", std::string("argmin_with_value")}, // P.ArgMinWithValue {"argmin_with_value", std::string("argmin_with_value")}, // P.ArgMinWithValue
{"tile", std::string("tile")}, // P.Tile {"tile", std::string("tile")}, // P.Tile
{"top_k", std::string("top_k")}, // P.TopK() {"topk", std::string("topk")}, // P.TopK()
{"isfinite", std::string("isfinite")}, // P.isfinite() {"isfinite", std::string("isfinite")}, // P.isfinite()
{"cos", std::string("cos")}, // cos() {"cos", std::string("cos")}, // cos()
{"cov", std::string("cov")}, // cov() {"cov", std::string("cov")}, // cov()

View File

@ -3366,11 +3366,11 @@ def ceil(x):
return F.ceil(x) return F.ceil(x)
def top_k(input_x, k, dim=None, sorted=True): def topk(input_x, k, dim=None, largest=True, sorted=True):
r""" r"""
For details, please refer to :func:`mindspore.ops.top_k`. For details, please refer to :func:`mindspore.ops.top_k`.
""" """
return F.top_k(input_x, k, dim, sorted) return F.topk(input_x, k, dim, largest=largest, sorted=sorted)
def subtract(x, other, *, alpha=1): def subtract(x, other, *, alpha=1):

View File

@ -3465,12 +3465,12 @@ class Tensor(Tensor_):
""" """
return tensor_operator_registry.get('tile')()(self, multiples) return tensor_operator_registry.get('tile')()(self, multiples)
def top_k(self, k, dim=None, sorted=True): def topk(self, k, dim=None, largest=True, sorted=True):
r""" r"""
For details, please refer to :func:`mindspore.ops.top_k`. For details, please refer to :func:`mindspore.ops.top_k`.
""" """
self._init_check() self._init_check()
return tensor_operator_registry.get("top_k")(self, k, dim, sorted) return tensor_operator_registry.get("topk")(self, k, dim, largest, sorted)
def sigmoid(self): def sigmoid(self):
r""" r"""

View File

@ -133,7 +133,7 @@ from .array_func import (
argmax, argmax,
min, min,
population_count, population_count,
top_k, topk,
expand, expand,
fold, fold,
unfold, unfold,

View File

@ -5614,19 +5614,19 @@ def unsorted_segment_sum(input_x, segment_ids, num_segments):
return unsorted_segment_sum_(input_x, segment_ids, num_segments) return unsorted_segment_sum_(input_x, segment_ids, num_segments)
def top_k(input_x, k, dim=None, sorted=True): def topk(input_x, k, dim=None, largest=True, sorted=True):
r""" r"""
Finds values and indices of the `k` largest entries along a given dimension. Finds values and indices of the `k` largest or smallest entries along a given dimension.
.. warning:: .. warning::
- If sorted is set to 'False', it will use the aicpu operator, the performance may be reduced. - If sorted is set to 'False', it will use the aicpu operator, the performance may be reduced.
If the `input_x` is a one-dimensional Tensor, finds the `k` largest entries in the Tensor, If the `input_x` is a one-dimensional Tensor, finds the `k` largest or smallest entries in the Tensor,
and outputs its value and index as a Tensor. Therefore, values[`k`] is the `k` largest item in `input_x`, and outputs its value and index as a Tensor. Therefore, values[`k`] is the `k` largest item in `input_x`,
and its index is indices [`k`]. and its index is indices [`k`].
For a multi-dimensional matrix, For a multi-dimensional matrix,
calculates the first `k` entries in a given dimension, therefore: calculates the first or last `k` entries in a given dimension, therefore:
.. math:: .. math::
@ -5636,15 +5636,16 @@ def top_k(input_x, k, dim=None, sorted=True):
Args: Args:
input_x (Tensor): Input to be computed, data type must be float16, float32 or int32. input_x (Tensor): Input to be computed, data type must be float16, float32 or int32.
k (int): The number of top elements to be computed along the last dimension, constant input is needed. k (int): The number of top or bottom elements to be computed along the last dimension, constant input is needed.
dim (int, optional): The dimension to sort along. Default: None. dim (int, optional): The dimension to sort along. Default: None.
largest (bool, optional): If largest is False then the k smallest elements are returned. Default: True.
sorted (bool, optional): If true, the obtained elements will be sorted by the values in descending order. sorted (bool, optional): If true, the obtained elements will be sorted by the values in descending order.
Default: True. Default: True.
Returns: Returns:
Tuple of 2 tensors, the values and the indices. Tuple of 2 tensors, the values and the indices.
- values (Tensor): The `k` largest elements in each slice of the given dimension. - values (Tensor): The `k` largest or smallest elements in each slice of the given dimension.
- indices (Tensor): The indices of values within the last dimension of input. - indices (Tensor): The indices of values within the last dimension of input.
Raises: Raises:
@ -5662,7 +5663,7 @@ def top_k(input_x, k, dim=None, sorted=True):
>>> x = ms.Tensor([[0.5368, 0.2447, 0.4302, 0.9673], >>> x = ms.Tensor([[0.5368, 0.2447, 0.4302, 0.9673],
... [0.4388, 0.6525, 0.4685, 0.1868], ... [0.4388, 0.6525, 0.4685, 0.1868],
... [0.3563, 0.5152, 0.9675, 0.8230]], dtype=ms.float32) ... [0.3563, 0.5152, 0.9675, 0.8230]], dtype=ms.float32)
>>> output = ops.top_k(x, 2, dim=1) >>> output = ops.topk(x, 2, dim=1)
>>> print(output) >>> print(output)
(Tensor(shape=[3, 2], dtype=Float32, value= (Tensor(shape=[3, 2], dtype=Float32, value=
[[ 9.67299998e-01, 5.36800027e-01], [[ 9.67299998e-01, 5.36800027e-01],
@ -5671,15 +5672,33 @@ def top_k(input_x, k, dim=None, sorted=True):
[[3, 0], [[3, 0],
[1, 2], [1, 2],
[2, 3]])) [2, 3]]))
>>> output2 = ops.topk(x, 2, dim=1, largest=False)
>>> print(output2)
(Tensor(shape=[3, 2], dtype=Float32, value=
[[ 2.44700000e-01, 4.30200011e-01],
[ 1.86800003e-01, 4.38800007e-01],
[ 3.56299996e-01, 5.15200019e-01]]), Tensor(shape=[3, 2], dtype=Int32, value=
[[1, 2],
[3, 0],
[0, 1]]))
""" """
top_k_ = _get_cache_prim(P.TopK)(sorted) top_k_ = _get_cache_prim(P.TopK)(sorted)
if not largest:
input_x = -input_x
if dim is None or dim == input_x.ndim - 1: if dim is None or dim == input_x.ndim - 1:
if not largest:
res = top_k_(input_x, k)
values, indices = -res[0], res[1]
return values, indices
return top_k_(input_x, k) return top_k_(input_x, k)
input_x = input_x.swapaxes(dim, input_x.ndim - 1) input_x = input_x.swapaxes(dim, input_x.ndim - 1)
output = top_k_(input_x, k) output = top_k_(input_x, k)
values = output[0].swapaxes(dim, input_x.ndim - 1) values = output[0].swapaxes(dim, input_x.ndim - 1)
indices = output[1].swapaxes(dim, input_x.ndim - 1) indices = output[1].swapaxes(dim, input_x.ndim - 1)
res = (values, indices) if not largest:
res = (-values, indices)
else:
res = (values, indices)
return res return res
@ -6471,7 +6490,7 @@ __all__ = [
'min', 'min',
'unsorted_segment_sum', 'unsorted_segment_sum',
'population_count', 'population_count',
'top_k', 'topk',
'expand', 'expand',
'fold', 'fold',
'unfold', 'unfold',

View File

@ -352,7 +352,7 @@ tensor_operator_registry.register('coalesce', coalesce)
tensor_operator_registry.register('argmax_with_value', max) tensor_operator_registry.register('argmax_with_value', max)
tensor_operator_registry.register('argmin_with_value', min) tensor_operator_registry.register('argmin_with_value', min)
tensor_operator_registry.register('coo_add', coo_add) tensor_operator_registry.register('coo_add', coo_add)
tensor_operator_registry.register('top_k', top_k) tensor_operator_registry.register('topk', topk)
tensor_operator_registry.register('isfinite', P.IsFinite) tensor_operator_registry.register('isfinite', P.IsFinite)
tensor_operator_registry.register('to', P.Cast) tensor_operator_registry.register('to', P.Cast)
tensor_operator_registry.register('bool', P.Cast) tensor_operator_registry.register('bool', P.Cast)

View File

@ -95,24 +95,24 @@ def test_top_k_functional():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
x_np = np.random.rand(3, 4).astype(np.float32) x_np = np.random.rand(3, 4).astype(np.float32)
k = 4 k = 4
ms_output = F.top_k(Tensor(x_np), k, sorted=True) ms_output = F.topk(Tensor(x_np), k, sorted=True)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output) assert np.allclose(ms_output[0].asnumpy(), np_output)
x_np = np.random.rand(3, 4).astype(np.float32) x_np = np.random.rand(3, 4).astype(np.float32)
k = 4 k = 4
ms_output = F.top_k(Tensor(x_np), k, sorted=False) ms_output = F.topk(Tensor(x_np), k, sorted=False)
assert np.allclose(ms_output[0].asnumpy(), x_np) assert np.allclose(ms_output[0].asnumpy(), x_np)
x_np = np.random.rand(2, 3, 4).astype(np.float32) x_np = np.random.rand(2, 3, 4).astype(np.float32)
k = 2 k = 2
ms_output = F.top_k(Tensor(x_np), k, sorted=True) ms_output = F.topk(Tensor(x_np), k, sorted=True)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output) assert np.allclose(ms_output[0].asnumpy(), np_output)
x_np = np.random.rand(512, 1024).astype(np.float32) x_np = np.random.rand(512, 1024).astype(np.float32)
k = 512 k = 512
ms_output = F.top_k(Tensor(x_np), k, sorted=True) ms_output = F.topk(Tensor(x_np), k, sorted=True)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output) assert np.allclose(ms_output[0].asnumpy(), np_output)
@ -129,23 +129,23 @@ def test_top_k_tensor():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
x_np = np.random.rand(3, 4).astype(np.float32) x_np = np.random.rand(3, 4).astype(np.float32)
k = 4 k = 4
ms_output = Tensor(x_np).top_k(k, sorted=True) ms_output = Tensor(x_np).topk(k, sorted=True)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output) assert np.allclose(ms_output[0].asnumpy(), np_output)
x_np = np.random.rand(3, 4).astype(np.float32) x_np = np.random.rand(3, 4).astype(np.float32)
k = 4 k = 4
ms_output = Tensor(x_np).top_k(k, sorted=False) ms_output = Tensor(x_np).topk(k, sorted=False)
assert np.allclose(ms_output[0].asnumpy(), x_np) assert np.allclose(ms_output[0].asnumpy(), x_np)
x_np = np.random.rand(2, 3, 4).astype(np.float32) x_np = np.random.rand(2, 3, 4).astype(np.float32)
k = 2 k = 2
ms_output = Tensor(x_np).top_k(k, sorted=True) ms_output = Tensor(x_np).topk(k, sorted=True)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output) assert np.allclose(ms_output[0].asnumpy(), np_output)
x_np = np.random.rand(512, 1024).astype(np.float32) x_np = np.random.rand(512, 1024).astype(np.float32)
k = 512 k = 512
ms_output = Tensor(x_np).top_k(k, sorted=True) ms_output = Tensor(x_np).topk(k, sorted=True)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output) assert np.allclose(ms_output[0].asnumpy(), np_output)

View File

@ -18,12 +18,13 @@ import pytest
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell): class Net(nn.Cell):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
def construct(self, input_x, k, dim=None, sorted=True): def construct(self, input_x, k, dim=None, largest=True, sorted=True):
output = input_x.top_k(k, dim=dim, sorted=sorted) output = ops.topk(input_x, k, dim=dim, largest=largest, sorted=sorted)
return output return output
@ -31,12 +32,14 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu @pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_top_k_normal(mode): def test_topk_normal(mode):
""" """
Feature: top_k Feature: topk
Description: Verify the result of top_k Description: Verify the result of topk
Expectation: success Expectation: success
""" """
ms.set_context(mode=mode) ms.set_context(mode=mode)
@ -53,5 +56,16 @@ def test_top_k_normal(mode):
expect_output1 = np.array([[3, 0], expect_output1 = np.array([[3, 0],
[1, 2], [1, 2],
[2, 3]]) [2, 3]])
output2 = net(x, 2, dim=1, largest=False)
output2_0 = output2[0]
output2_1 = output2[1]
expect_output2_0 = np.array([[2.44700000e-01, 4.30200011e-01],
[1.86800003e-01, 4.38800007e-01],
[3.56299996e-01, 5.15200019e-01]])
expect_output2_1 = np.array([[1, 2],
[3, 0],
[0, 1]])
assert np.allclose(output0.asnumpy(), expect_output0) assert np.allclose(output0.asnumpy(), expect_output0)
assert np.allclose(output1.asnumpy(), expect_output1) assert np.allclose(output1.asnumpy(), expect_output1)
assert np.allclose(output2_0.asnumpy(), expect_output2_0)
assert np.allclose(output2_1.asnumpy(), expect_output2_1)

View File

@ -18,13 +18,12 @@ import pytest
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell): class Net(nn.Cell):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
def construct(self, input_x, k, dim=None, sorted=True): def construct(self, input_x, k, dim=None, largest=True, sorted=True):
output = ops.top_k(input_x, k, dim=dim, sorted=sorted) output = input_x.topk(k, dim=dim, largest=largest, sorted=sorted)
return output return output
@ -32,12 +31,14 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu @pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_top_k_normal(mode): def test_topk_normal(mode):
""" """
Feature: top_k Feature: top_k
Description: Verify the result of top_k Description: Verify the result of topk
Expectation: success Expectation: success
""" """
ms.set_context(mode=mode) ms.set_context(mode=mode)
@ -54,5 +55,16 @@ def test_top_k_normal(mode):
expect_output1 = np.array([[3, 0], expect_output1 = np.array([[3, 0],
[1, 2], [1, 2],
[2, 3]]) [2, 3]])
output2 = net(x, 2, dim=1, largest=False)
output2_0 = output2[0]
output2_1 = output2[1]
expect_output2_0 = np.array([[2.44700000e-01, 4.30200011e-01],
[1.86800003e-01, 4.38800007e-01],
[3.56299996e-01, 5.15200019e-01]])
expect_output2_1 = np.array([[1, 2],
[3, 0],
[0, 1]])
assert np.allclose(output0.asnumpy(), expect_output0) assert np.allclose(output0.asnumpy(), expect_output0)
assert np.allclose(output1.asnumpy(), expect_output1) assert np.allclose(output1.asnumpy(), expect_output1)
assert np.allclose(output2_0.asnumpy(), expect_output2_0)
assert np.allclose(output2_1.asnumpy(), expect_output2_1)