impl review suggestions for PR: 62636 add ST for topk: support k with mutable Tensor type mpl review suggestions for PR: 63268
This commit is contained in:
parent
772542e148
commit
9c3e963dbe
|
@ -27,7 +27,7 @@
|
|||
- GPU:float16、float32。
|
||||
- CPU:所有数值型。
|
||||
|
||||
- **k** (int) - 指定计算最大元素的数量,必须为常量。
|
||||
- **k** (Union(Tensor, int)) - 指定计算最大元素的数量。若 `k` 为Tensor,其数据类型须为int32。若为Tensor,只支持零维Tensor或shape为 :math:`(1, )` 的一维Tensor。
|
||||
|
||||
输出:
|
||||
由 `values` 和 `indices` 组成的tuple。
|
||||
|
|
|
@ -19,7 +19,7 @@ mindspore.ops.topk
|
|||
|
||||
参数:
|
||||
- **input** (Tensor) - 需计算的输入,数据类型必须为float16、float32或int32。
|
||||
- **k** (int) - 指定计算最大或最小元素的数量,必须为常量。
|
||||
- **k** (int) - 指定计算最大或最小元素的数量。
|
||||
- **dim** (int, 可选) - 需要排序的维度。默认值: ``None`` 。
|
||||
- **largest** (bool, 可选) - 如果为 ``False`` ,则会返回前k个最小值。默认值: ``True`` 。
|
||||
- **sorted** (bool, 可选) - 如果为 ``True`` ,则获取的元素将按值降序排序。如果为 ``False`` ,则不对获取的元素进行排序。默认值: ``True`` 。
|
||||
|
|
|
@ -694,7 +694,7 @@ template <typename T, typename S>
|
|||
void ReluBool(ArithmeticSelfCpuKernelFuncBool<T, S> *content, const T *in, S *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = in[i] ? true : false;
|
||||
out[i] = in[i];
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_);
|
||||
|
|
|
@ -52,6 +52,12 @@ abstract::TupleShapePtr TopKInferShape(const PrimitivePtr &primitive, const std:
|
|||
|
||||
// 2rd input is a Tensor when TopK is a dynamic shape operator
|
||||
if (CheckAndConvertUtils::IsTensor(input_args[kInputIndex1])) {
|
||||
auto k_dim = input_args[kInputIndex1]->GetShape()->GetShapeVector().size();
|
||||
if (k_dim > 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << prim_name
|
||||
<< "', the dimension of 'k' should only be 0 or 1 when 'k' is a Tensor, but got: " << k_dim
|
||||
<< ".";
|
||||
}
|
||||
auto k_val = CheckAndConvertUtils::CheckTensorIntValue("k", input_args[kInputIndex1]->GetValue(), prim_name,
|
||||
input_args[kInputIndex1]->GetType());
|
||||
k_v = k_val[0];
|
||||
|
@ -78,7 +84,7 @@ TuplePtr TopKInferType(const PrimitivePtr &primitive, const std::vector<Abstract
|
|||
auto output0_type = input_args[kInputIndex0]->GetType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", output0_type, common_valid_types, prim_name);
|
||||
auto k_type = input_args[kInputIndex1]->GetType();
|
||||
const std::set<TypePtr> int_types = {kInt8, kInt16, kInt32, kInt64};
|
||||
const std::set<TypePtr> int_types = {kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("k", k_type, int_types, prim_name);
|
||||
auto output1_type = kInt32;
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{output0_type, output1_type});
|
||||
|
|
|
@ -451,7 +451,7 @@ def reverse(x, axis):
|
|||
:func:`mindspore.ops.reverse` will be deprecated in the future.
|
||||
Please use :func:`mindspore.ops.flip` instead.
|
||||
"""
|
||||
return _get_cache_prim(P.array_ops.ReverseV2)(axis)(x)
|
||||
return flip(x, axis)
|
||||
|
||||
|
||||
def ravel(input):
|
||||
|
@ -5328,7 +5328,7 @@ def topk(input, k, dim=None, largest=True, sorted=True):
|
|||
|
||||
Args:
|
||||
input (Tensor): Input to be computed, data type must be float16, float32 or int32.
|
||||
k (int): The number of top or bottom elements to be computed along the last dimension, constant input is needed.
|
||||
k (int): The number of top or bottom elements to be computed along the last dimension.
|
||||
dim (int, optional): The dimension to sort along. Default: ``None`` .
|
||||
largest (bool, optional): If largest is ``False`` then the k smallest elements are returned.
|
||||
Default: ``True`` .
|
||||
|
|
|
@ -6264,7 +6264,8 @@ class TopK(Primitive):
|
|||
- GPU: float16, float32.
|
||||
- CPU: all numeric types.
|
||||
|
||||
- **k** (int) - The number of top elements to be computed along the last dimension, constant input is needed.
|
||||
- **k** (Union(Tensor, int)) - The number of top elements to be computed along the last dimension.
|
||||
If `k` is a Tensor, the supported dtype is int32 and it should be 0-D or 1-D with shape :math:`(1, )` .
|
||||
|
||||
Outputs:
|
||||
A tuple consisting of `values` and `indexes`.
|
||||
|
|
|
@ -19,6 +19,7 @@ import pytest
|
|||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import mutable, Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
@ -69,3 +70,48 @@ def test_topk_normal(mode):
|
|||
assert np.allclose(output1.asnumpy(), expect_output1, rtol=1e-3, atol=1e-5)
|
||||
assert np.allclose(output2_0.asnumpy(), expect_output2_0, rtol=1e-3, atol=1e-5)
|
||||
assert np.allclose(output2_1.asnumpy(), expect_output2_1, rtol=1e-3, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.level2
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
@pytest.mark.parametrize('input_k', [mutable(Tensor(2, ms.int32)), mutable(Tensor([2], ms.int32))])
|
||||
def test_topk_mutable_k(mode, input_k):
|
||||
"""
|
||||
Feature: topk
|
||||
Description: Verify the result of topk with mutable Tensor `k` as input.
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor([[0.5368, 0.2447, 0.4302, 0.9673],
|
||||
[0.4388, 0.6525, 0.4685, 0.1868],
|
||||
[0.3563, 0.5152, 0.9675, 0.8230]], dtype=ms.float32)
|
||||
k = input_k
|
||||
output = net(x, k, dim=1)
|
||||
output0 = output[0]
|
||||
output1 = output[1]
|
||||
expect_output0 = np.array([[0.9673, 0.5368],
|
||||
[0.6525, 0.4685],
|
||||
[0.9675, 0.823]])
|
||||
expect_output1 = np.array([[3, 0],
|
||||
[1, 2],
|
||||
[2, 3]])
|
||||
output2 = net(x, k, dim=1, largest=False)
|
||||
output2_0 = output2[0]
|
||||
output2_1 = output2[1]
|
||||
expect_output2_0 = np.array([[2.44700000e-01, 4.30200011e-01],
|
||||
[1.86800003e-01, 4.38800007e-01],
|
||||
[3.56299996e-01, 5.15200019e-01]])
|
||||
expect_output2_1 = np.array([[1, 2],
|
||||
[3, 0],
|
||||
[0, 1]])
|
||||
assert np.allclose(output0.asnumpy(), expect_output0, rtol=1e-3, atol=1e-5)
|
||||
assert np.allclose(output1.asnumpy(), expect_output1, rtol=1e-3, atol=1e-5)
|
||||
assert np.allclose(output2_0.asnumpy(), expect_output2_0, rtol=1e-3, atol=1e-5)
|
||||
assert np.allclose(output2_1.asnumpy(), expect_output2_1, rtol=1e-3, atol=1e-5)
|
||||
|
|
Loading…
Reference in New Issue