impl review suggestions for PR: 62636 add ST for topk: support k with mutable Tensor type mpl review suggestions for PR: 63268

This commit is contained in:
lilinjie 2024-03-04 10:44:21 +08:00 committed by hedongdong
parent 772542e148
commit 9c3e963dbe
7 changed files with 60 additions and 7 deletions

View File

@ -27,7 +27,7 @@
- GPUfloat16、float32。
- CPU所有数值型。
- **k** (int) - 指定计算最大元素的数量,必须为常量
- **k** (Union(Tensor, int)) - 指定计算最大元素的数量。若 `k` 为Tensor其数据类型须为int32。若为Tensor只支持零维Tensor或shape为 :math:`(1, )` 的一维Tensor
输出:
`values``indices` 组成的tuple。

View File

@ -19,7 +19,7 @@ mindspore.ops.topk
参数:
- **input** (Tensor) - 需计算的输入数据类型必须为float16、float32或int32。
- **k** (int) - 指定计算最大或最小元素的数量,必须为常量
- **k** (int) - 指定计算最大或最小元素的数量。
- **dim** (int, 可选) - 需要排序的维度。默认值: ``None``
- **largest** (bool, 可选) - 如果为 ``False`` 则会返回前k个最小值。默认值 ``True``
- **sorted** (bool, 可选) - 如果为 ``True`` ,则获取的元素将按值降序排序。如果为 ``False`` ,则不对获取的元素进行排序。默认值: ``True``

View File

@ -694,7 +694,7 @@ template <typename T, typename S>
void ReluBool(ArithmeticSelfCpuKernelFuncBool<T, S> *content, const T *in, S *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = in[i] ? true : false;
out[i] = in[i];
}
};
ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_);

View File

@ -52,6 +52,12 @@ abstract::TupleShapePtr TopKInferShape(const PrimitivePtr &primitive, const std:
// 2rd input is a Tensor when TopK is a dynamic shape operator
if (CheckAndConvertUtils::IsTensor(input_args[kInputIndex1])) {
auto k_dim = input_args[kInputIndex1]->GetShape()->GetShapeVector().size();
if (k_dim > 1) {
MS_LOG(EXCEPTION) << "For '" << prim_name
<< "', the dimension of 'k' should only be 0 or 1 when 'k' is a Tensor, but got: " << k_dim
<< ".";
}
auto k_val = CheckAndConvertUtils::CheckTensorIntValue("k", input_args[kInputIndex1]->GetValue(), prim_name,
input_args[kInputIndex1]->GetType());
k_v = k_val[0];
@ -78,7 +84,7 @@ TuplePtr TopKInferType(const PrimitivePtr &primitive, const std::vector<Abstract
auto output0_type = input_args[kInputIndex0]->GetType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", output0_type, common_valid_types, prim_name);
auto k_type = input_args[kInputIndex1]->GetType();
const std::set<TypePtr> int_types = {kInt8, kInt16, kInt32, kInt64};
const std::set<TypePtr> int_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTypeValid("k", k_type, int_types, prim_name);
auto output1_type = kInt32;
return std::make_shared<Tuple>(std::vector<TypePtr>{output0_type, output1_type});

View File

@ -451,7 +451,7 @@ def reverse(x, axis):
:func:`mindspore.ops.reverse` will be deprecated in the future.
Please use :func:`mindspore.ops.flip` instead.
"""
return _get_cache_prim(P.array_ops.ReverseV2)(axis)(x)
return flip(x, axis)
def ravel(input):
@ -5328,7 +5328,7 @@ def topk(input, k, dim=None, largest=True, sorted=True):
Args:
input (Tensor): Input to be computed, data type must be float16, float32 or int32.
k (int): The number of top or bottom elements to be computed along the last dimension, constant input is needed.
k (int): The number of top or bottom elements to be computed along the last dimension.
dim (int, optional): The dimension to sort along. Default: ``None`` .
largest (bool, optional): If largest is ``False`` then the k smallest elements are returned.
Default: ``True`` .

View File

@ -6264,7 +6264,8 @@ class TopK(Primitive):
- GPU: float16, float32.
- CPU: all numeric types.
- **k** (int) - The number of top elements to be computed along the last dimension, constant input is needed.
- **k** (Union(Tensor, int)) - The number of top elements to be computed along the last dimension.
If `k` is a Tensor, the supported dtype is int32 and it should be 0-D or 1-D with shape :math:`(1, )` .
Outputs:
A tuple consisting of `values` and `indexes`.

View File

@ -19,6 +19,7 @@ import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import mutable, Tensor
class Net(nn.Cell):
@ -69,3 +70,48 @@ def test_topk_normal(mode):
assert np.allclose(output1.asnumpy(), expect_output1, rtol=1e-3, atol=1e-5)
assert np.allclose(output2_0.asnumpy(), expect_output2_0, rtol=1e-3, atol=1e-5)
assert np.allclose(output2_1.asnumpy(), expect_output2_1, rtol=1e-3, atol=1e-5)
@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.parametrize('input_k', [mutable(Tensor(2, ms.int32)), mutable(Tensor([2], ms.int32))])
def test_topk_mutable_k(mode, input_k):
"""
Feature: topk
Description: Verify the result of topk with mutable Tensor `k` as input.
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor([[0.5368, 0.2447, 0.4302, 0.9673],
[0.4388, 0.6525, 0.4685, 0.1868],
[0.3563, 0.5152, 0.9675, 0.8230]], dtype=ms.float32)
k = input_k
output = net(x, k, dim=1)
output0 = output[0]
output1 = output[1]
expect_output0 = np.array([[0.9673, 0.5368],
[0.6525, 0.4685],
[0.9675, 0.823]])
expect_output1 = np.array([[3, 0],
[1, 2],
[2, 3]])
output2 = net(x, k, dim=1, largest=False)
output2_0 = output2[0]
output2_1 = output2[1]
expect_output2_0 = np.array([[2.44700000e-01, 4.30200011e-01],
[1.86800003e-01, 4.38800007e-01],
[3.56299996e-01, 5.15200019e-01]])
expect_output2_1 = np.array([[1, 2],
[3, 0],
[0, 1]])
assert np.allclose(output0.asnumpy(), expect_output0, rtol=1e-3, atol=1e-5)
assert np.allclose(output1.asnumpy(), expect_output1, rtol=1e-3, atol=1e-5)
assert np.allclose(output2_0.asnumpy(), expect_output2_0, rtol=1e-3, atol=1e-5)
assert np.allclose(output2_1.asnumpy(), expect_output2_1, rtol=1e-3, atol=1e-5)