forked from mindspore-Ecosystem/mindspore
!47926 modify topk to pytorch(add largest)
Merge pull request !47926 from 冯一航/modify_top_k_topk
This commit is contained in:
commit
fb36b4b482
|
@ -1,16 +1,16 @@
|
|||
mindspore.ops.top_k
|
||||
mindspore.ops.topk
|
||||
===================
|
||||
|
||||
.. py:function:: mindspore.ops.top_k(input_x, k, dim=None, sorted=True)
|
||||
.. py:function:: mindspore.ops.topk(input_x, k, dim=None, largest=True, sorted=True):
|
||||
|
||||
沿给定维度查找 `k` 个最大元素和对应的索引。
|
||||
沿给定维度查找 `k` 个最大或最小元素和对应的索引。
|
||||
|
||||
.. warning::
|
||||
- 如果 `sorted` 设置为'False',它将使用aicpu运算符,性能可能会降低。
|
||||
|
||||
如果 `input_x` 是一维Tensor,则查找Tensor中 `k` 个最大元素,并将其值和索引输出为Tensor。因此, `values[k]` 是 `input_x` 中 `k` 个最大元素,其索引是 `indices[k]` 。
|
||||
如果 `input_x` 是一维Tensor,则查找Tensor中 `k` 个最大或最小元素,并将其值和索引输出为Tensor。因此, `values[k]` 是 `input_x` 中 `k` 个最大元素,其索引是 `indices[k]` 。
|
||||
|
||||
对于多维矩阵,计算给定维度中最大的 `k` 个元素,因此:
|
||||
对于多维矩阵,计算给定维度中最大或最小的 `k` 个元素,因此:
|
||||
|
||||
.. math::
|
||||
values.shape = indices.shape
|
||||
|
@ -19,14 +19,15 @@ mindspore.ops.top_k
|
|||
|
||||
参数:
|
||||
- **input_x** (Tensor) - 需计算的输入,数据类型必须为float16、float32或int32。
|
||||
- **k** (int) - 指定计算最大元素的数量,需要是常量。
|
||||
- **k** (int) - 指定计算最大或最小元素的数量,需要是常量。
|
||||
- **dim** (int, 可选) - 需要排序的维度。默认值:None。
|
||||
- **largest** (bool, 可选) - 如果为False,则会返回前k个最小值。默认值:True。
|
||||
- **sorted** (bool, 可选) - 如果为True,则获取的元素将按值降序排序。默认值:True。
|
||||
|
||||
返回:
|
||||
2个Tensor组成的tuple, `values` 和 `indices` 。
|
||||
|
||||
- **values** (Tensor) - 给定维度的每个切片中的 `k` 最大元素。
|
||||
- **values** (Tensor) - 给定维度的每个切片中的 `k` 最大元素或最小元素。
|
||||
- **indices** (Tensor) - `k` 最大元素的对应索引。
|
||||
|
||||
异常:
|
|
@ -371,7 +371,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"argmax_with_value", std::string("argmax_with_value")}, // P.ArgMaxWithValue
|
||||
{"argmin_with_value", std::string("argmin_with_value")}, // P.ArgMinWithValue
|
||||
{"tile", std::string("tile")}, // P.Tile
|
||||
{"top_k", std::string("top_k")}, // P.TopK()
|
||||
{"topk", std::string("topk")}, // P.TopK()
|
||||
{"isfinite", std::string("isfinite")}, // P.isfinite()
|
||||
{"cos", std::string("cos")}, // cos()
|
||||
{"cov", std::string("cov")}, // cov()
|
||||
|
|
|
@ -3366,11 +3366,11 @@ def ceil(x):
|
|||
return F.ceil(x)
|
||||
|
||||
|
||||
def top_k(input_x, k, dim=None, sorted=True):
|
||||
def topk(input_x, k, dim=None, largest=True, sorted=True):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.top_k`.
|
||||
"""
|
||||
return F.top_k(input_x, k, dim, sorted)
|
||||
return F.topk(input_x, k, dim, largest=largest, sorted=sorted)
|
||||
|
||||
|
||||
def subtract(x, other, *, alpha=1):
|
||||
|
|
|
@ -3465,12 +3465,12 @@ class Tensor(Tensor_):
|
|||
"""
|
||||
return tensor_operator_registry.get('tile')()(self, multiples)
|
||||
|
||||
def top_k(self, k, dim=None, sorted=True):
|
||||
def topk(self, k, dim=None, largest=True, sorted=True):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.top_k`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get("top_k")(self, k, dim, sorted)
|
||||
return tensor_operator_registry.get("topk")(self, k, dim, largest, sorted)
|
||||
|
||||
def sigmoid(self):
|
||||
r"""
|
||||
|
|
|
@ -133,7 +133,7 @@ from .array_func import (
|
|||
argmax,
|
||||
min,
|
||||
population_count,
|
||||
top_k,
|
||||
topk,
|
||||
expand,
|
||||
fold,
|
||||
unfold,
|
||||
|
|
|
@ -5614,19 +5614,19 @@ def unsorted_segment_sum(input_x, segment_ids, num_segments):
|
|||
return unsorted_segment_sum_(input_x, segment_ids, num_segments)
|
||||
|
||||
|
||||
def top_k(input_x, k, dim=None, sorted=True):
|
||||
def topk(input_x, k, dim=None, largest=True, sorted=True):
|
||||
r"""
|
||||
Finds values and indices of the `k` largest entries along a given dimension.
|
||||
Finds values and indices of the `k` largest or smallest entries along a given dimension.
|
||||
|
||||
.. warning::
|
||||
- If sorted is set to 'False', it will use the aicpu operator, the performance may be reduced.
|
||||
|
||||
If the `input_x` is a one-dimensional Tensor, finds the `k` largest entries in the Tensor,
|
||||
If the `input_x` is a one-dimensional Tensor, finds the `k` largest or smallest entries in the Tensor,
|
||||
and outputs its value and index as a Tensor. Therefore, values[`k`] is the `k` largest item in `input_x`,
|
||||
and its index is indices [`k`].
|
||||
|
||||
For a multi-dimensional matrix,
|
||||
calculates the first `k` entries in a given dimension, therefore:
|
||||
calculates the first or last `k` entries in a given dimension, therefore:
|
||||
|
||||
.. math::
|
||||
|
||||
|
@ -5636,15 +5636,16 @@ def top_k(input_x, k, dim=None, sorted=True):
|
|||
|
||||
Args:
|
||||
input_x (Tensor): Input to be computed, data type must be float16, float32 or int32.
|
||||
k (int): The number of top elements to be computed along the last dimension, constant input is needed.
|
||||
k (int): The number of top or bottom elements to be computed along the last dimension, constant input is needed.
|
||||
dim (int, optional): The dimension to sort along. Default: None.
|
||||
largest (bool, optional): If largest is False then the k smallest elements are returned. Default: True.
|
||||
sorted (bool, optional): If true, the obtained elements will be sorted by the values in descending order.
|
||||
Default: True.
|
||||
|
||||
Returns:
|
||||
Tuple of 2 tensors, the values and the indices.
|
||||
|
||||
- values (Tensor): The `k` largest elements in each slice of the given dimension.
|
||||
- values (Tensor): The `k` largest or smallest elements in each slice of the given dimension.
|
||||
- indices (Tensor): The indices of values within the last dimension of input.
|
||||
|
||||
Raises:
|
||||
|
@ -5662,7 +5663,7 @@ def top_k(input_x, k, dim=None, sorted=True):
|
|||
>>> x = ms.Tensor([[0.5368, 0.2447, 0.4302, 0.9673],
|
||||
... [0.4388, 0.6525, 0.4685, 0.1868],
|
||||
... [0.3563, 0.5152, 0.9675, 0.8230]], dtype=ms.float32)
|
||||
>>> output = ops.top_k(x, 2, dim=1)
|
||||
>>> output = ops.topk(x, 2, dim=1)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[3, 2], dtype=Float32, value=
|
||||
[[ 9.67299998e-01, 5.36800027e-01],
|
||||
|
@ -5671,14 +5672,32 @@ def top_k(input_x, k, dim=None, sorted=True):
|
|||
[[3, 0],
|
||||
[1, 2],
|
||||
[2, 3]]))
|
||||
>>> output2 = ops.topk(x, 2, dim=1, largest=False)
|
||||
>>> print(output2)
|
||||
(Tensor(shape=[3, 2], dtype=Float32, value=
|
||||
[[ 2.44700000e-01, 4.30200011e-01],
|
||||
[ 1.86800003e-01, 4.38800007e-01],
|
||||
[ 3.56299996e-01, 5.15200019e-01]]), Tensor(shape=[3, 2], dtype=Int32, value=
|
||||
[[1, 2],
|
||||
[3, 0],
|
||||
[0, 1]]))
|
||||
"""
|
||||
top_k_ = _get_cache_prim(P.TopK)(sorted)
|
||||
if not largest:
|
||||
input_x = -input_x
|
||||
if dim is None or dim == input_x.ndim - 1:
|
||||
if not largest:
|
||||
res = top_k_(input_x, k)
|
||||
values, indices = -res[0], res[1]
|
||||
return values, indices
|
||||
return top_k_(input_x, k)
|
||||
input_x = input_x.swapaxes(dim, input_x.ndim - 1)
|
||||
output = top_k_(input_x, k)
|
||||
values = output[0].swapaxes(dim, input_x.ndim - 1)
|
||||
indices = output[1].swapaxes(dim, input_x.ndim - 1)
|
||||
if not largest:
|
||||
res = (-values, indices)
|
||||
else:
|
||||
res = (values, indices)
|
||||
return res
|
||||
|
||||
|
@ -6471,7 +6490,7 @@ __all__ = [
|
|||
'min',
|
||||
'unsorted_segment_sum',
|
||||
'population_count',
|
||||
'top_k',
|
||||
'topk',
|
||||
'expand',
|
||||
'fold',
|
||||
'unfold',
|
||||
|
|
|
@ -352,7 +352,7 @@ tensor_operator_registry.register('coalesce', coalesce)
|
|||
tensor_operator_registry.register('argmax_with_value', max)
|
||||
tensor_operator_registry.register('argmin_with_value', min)
|
||||
tensor_operator_registry.register('coo_add', coo_add)
|
||||
tensor_operator_registry.register('top_k', top_k)
|
||||
tensor_operator_registry.register('topk', topk)
|
||||
tensor_operator_registry.register('isfinite', P.IsFinite)
|
||||
tensor_operator_registry.register('to', P.Cast)
|
||||
tensor_operator_registry.register('bool', P.Cast)
|
||||
|
|
|
@ -95,24 +95,24 @@ def test_top_k_functional():
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
x_np = np.random.rand(3, 4).astype(np.float32)
|
||||
k = 4
|
||||
ms_output = F.top_k(Tensor(x_np), k, sorted=True)
|
||||
ms_output = F.topk(Tensor(x_np), k, sorted=True)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(3, 4).astype(np.float32)
|
||||
k = 4
|
||||
ms_output = F.top_k(Tensor(x_np), k, sorted=False)
|
||||
ms_output = F.topk(Tensor(x_np), k, sorted=False)
|
||||
assert np.allclose(ms_output[0].asnumpy(), x_np)
|
||||
|
||||
x_np = np.random.rand(2, 3, 4).astype(np.float32)
|
||||
k = 2
|
||||
ms_output = F.top_k(Tensor(x_np), k, sorted=True)
|
||||
ms_output = F.topk(Tensor(x_np), k, sorted=True)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(512, 1024).astype(np.float32)
|
||||
k = 512
|
||||
ms_output = F.top_k(Tensor(x_np), k, sorted=True)
|
||||
ms_output = F.topk(Tensor(x_np), k, sorted=True)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
|
@ -129,23 +129,23 @@ def test_top_k_tensor():
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
x_np = np.random.rand(3, 4).astype(np.float32)
|
||||
k = 4
|
||||
ms_output = Tensor(x_np).top_k(k, sorted=True)
|
||||
ms_output = Tensor(x_np).topk(k, sorted=True)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(3, 4).astype(np.float32)
|
||||
k = 4
|
||||
ms_output = Tensor(x_np).top_k(k, sorted=False)
|
||||
ms_output = Tensor(x_np).topk(k, sorted=False)
|
||||
assert np.allclose(ms_output[0].asnumpy(), x_np)
|
||||
|
||||
x_np = np.random.rand(2, 3, 4).astype(np.float32)
|
||||
k = 2
|
||||
ms_output = Tensor(x_np).top_k(k, sorted=True)
|
||||
ms_output = Tensor(x_np).topk(k, sorted=True)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(512, 1024).astype(np.float32)
|
||||
k = 512
|
||||
ms_output = Tensor(x_np).top_k(k, sorted=True)
|
||||
ms_output = Tensor(x_np).topk(k, sorted=True)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
|
|
@ -18,12 +18,13 @@ import pytest
|
|||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
# pylint: disable=redefined-builtin
|
||||
def construct(self, input_x, k, dim=None, sorted=True):
|
||||
output = input_x.top_k(k, dim=dim, sorted=sorted)
|
||||
def construct(self, input_x, k, dim=None, largest=True, sorted=True):
|
||||
output = ops.topk(input_x, k, dim=dim, largest=largest, sorted=sorted)
|
||||
return output
|
||||
|
||||
|
||||
|
@ -31,12 +32,14 @@ class Net(nn.Cell):
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_top_k_normal(mode):
|
||||
def test_topk_normal(mode):
|
||||
"""
|
||||
Feature: top_k
|
||||
Description: Verify the result of top_k
|
||||
Feature: topk
|
||||
Description: Verify the result of topk
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
|
@ -53,5 +56,16 @@ def test_top_k_normal(mode):
|
|||
expect_output1 = np.array([[3, 0],
|
||||
[1, 2],
|
||||
[2, 3]])
|
||||
output2 = net(x, 2, dim=1, largest=False)
|
||||
output2_0 = output2[0]
|
||||
output2_1 = output2[1]
|
||||
expect_output2_0 = np.array([[2.44700000e-01, 4.30200011e-01],
|
||||
[1.86800003e-01, 4.38800007e-01],
|
||||
[3.56299996e-01, 5.15200019e-01]])
|
||||
expect_output2_1 = np.array([[1, 2],
|
||||
[3, 0],
|
||||
[0, 1]])
|
||||
assert np.allclose(output0.asnumpy(), expect_output0)
|
||||
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||
assert np.allclose(output2_0.asnumpy(), expect_output2_0)
|
||||
assert np.allclose(output2_1.asnumpy(), expect_output2_1)
|
|
@ -18,13 +18,12 @@ import pytest
|
|||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
# pylint: disable=redefined-builtin
|
||||
def construct(self, input_x, k, dim=None, sorted=True):
|
||||
output = ops.top_k(input_x, k, dim=dim, sorted=sorted)
|
||||
def construct(self, input_x, k, dim=None, largest=True, sorted=True):
|
||||
output = input_x.topk(k, dim=dim, largest=largest, sorted=sorted)
|
||||
return output
|
||||
|
||||
|
||||
|
@ -32,12 +31,14 @@ class Net(nn.Cell):
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_top_k_normal(mode):
|
||||
def test_topk_normal(mode):
|
||||
"""
|
||||
Feature: top_k
|
||||
Description: Verify the result of top_k
|
||||
Description: Verify the result of topk
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
|
@ -54,5 +55,16 @@ def test_top_k_normal(mode):
|
|||
expect_output1 = np.array([[3, 0],
|
||||
[1, 2],
|
||||
[2, 3]])
|
||||
output2 = net(x, 2, dim=1, largest=False)
|
||||
output2_0 = output2[0]
|
||||
output2_1 = output2[1]
|
||||
expect_output2_0 = np.array([[2.44700000e-01, 4.30200011e-01],
|
||||
[1.86800003e-01, 4.38800007e-01],
|
||||
[3.56299996e-01, 5.15200019e-01]])
|
||||
expect_output2_1 = np.array([[1, 2],
|
||||
[3, 0],
|
||||
[0, 1]])
|
||||
assert np.allclose(output0.asnumpy(), expect_output0)
|
||||
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||
assert np.allclose(output2_0.asnumpy(), expect_output2_0)
|
||||
assert np.allclose(output2_1.asnumpy(), expect_output2_1)
|
Loading…
Reference in New Issue