forked from mindspore-Ecosystem/mindspore
[MS][LITE] update some ops functional API support
This commit is contained in:
parent
7a0f8a6770
commit
f0125d4d78
|
@ -60,6 +60,10 @@
|
|||
"mindspore/mindspore/python/mindspore/ops/functional.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/functional.py" "wildcard-import"
|
||||
"mindspore/mindspore/python/mindspore/ops/functional.py" "unused-wildcard-import"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/standard_method.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/common/tensor.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/function/array_func.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/array_ops.py" "redefined-builtin"
|
||||
|
||||
# MindData
|
||||
"mindspore/mindspore/python/mindspore/dataset/__init__.py" "redefined-builtin"
|
||||
|
|
|
@ -413,6 +413,7 @@ Array操作
|
|||
mindspore.ops.tensor_scatter_sub
|
||||
mindspore.ops.tensor_scatter_elements
|
||||
mindspore.ops.tile
|
||||
mindspore.ops.top_k
|
||||
mindspore.ops.transpose
|
||||
mindspore.ops.unique
|
||||
mindspore.ops.unique_consecutive
|
||||
|
|
|
@ -1757,6 +1757,41 @@ mindspore.Tensor
|
|||
- **TypeError** - `indices` 的数据类型既不是int32,也不是int64。
|
||||
- **ValueError** - Tensor的shape长度小于 `indices` 的shape的最后一个维度。
|
||||
|
||||
.. py:method:: top_k(k, sorted=True)
|
||||
|
||||
沿最后一个维度查找 `k` 个最大元素和对应的索引。
|
||||
|
||||
.. warning::
|
||||
- 如果 `sorted` 设置为'False',它将使用aicpu运算符,性能可能会降低。
|
||||
|
||||
`x` 指的当前 Tensor。
|
||||
|
||||
如果 `input_x` 是一维Tensor,则查找Tensor中 `k` 个最大元素,并将其值和索引输出为Tensor。因此, `values[k]` 是 `input_x` 中 `k` 个最大元素,其索引是 `indices[k]` 。
|
||||
|
||||
对于多维矩阵,计算每行中最大的 `k` 个元素(沿最后一个维度的相应向量),因此:
|
||||
|
||||
.. math::
|
||||
values.shape = indices.shape = input\_x.shape[:-1] + [k].
|
||||
|
||||
如果两个比较的元素相同,则优先返回索引值较小的元素。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **k** (int) - 指定计算最大元素的数量,需要是常量。
|
||||
- **sorted** (bool, optional) - 如果为True,则获取的元素将按值降序排序。默认值:True。
|
||||
|
||||
**返回:**
|
||||
|
||||
2个Tensor组成的tuple, `values` 和 `indices` 。
|
||||
|
||||
- **values** (Tensor) - 最后一个维度的每个切片中的 `k` 最大元素。
|
||||
- **indices** (Tensor) - `k` 最大元素的对应索引。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - 如果 `k` 不是int。
|
||||
- **TypeError** - 如果 `sorted` 不是bool。
|
||||
|
||||
.. py:method:: scatter_max(indices, updates)
|
||||
|
||||
根据指定的更新值和输入索引,通过最大值运算,输出结果以Tensor形式返回。
|
||||
|
|
|
@ -5,33 +5,4 @@
|
|||
|
||||
沿最后一个维度查找 `k` 个最大元素和对应的索引。
|
||||
|
||||
.. warning::
|
||||
- 如果 `sorted` 设置为'False',它将使用aicpu运算符,性能可能会降低。
|
||||
|
||||
如果 `input_x` 是一维Tensor,则查找Tensor中 `k` 个最大元素,并将其值和索引输出为Tensor。因此, `values[k]` 是 `input_x` 中 `k` 个最大元素,其索引是 `indices[k]` 。
|
||||
|
||||
对于多维矩阵,计算每行中最大的 `k` 个元素(沿最后一个维度的相应向量),因此:
|
||||
|
||||
.. math::
|
||||
values.shape = indices.shape = input.shape[:-1] + [k].
|
||||
|
||||
如果两个比较的元素相同,则优先返回索引值较小的元素。
|
||||
|
||||
参数:
|
||||
- **sorted** (bool) - 如果为True,则获取的元素将按值降序排序。默认值:True。
|
||||
|
||||
输入:
|
||||
- **input_x** (Tensor) - 需计算的输入,数据类型必须为float16、float32或int32。
|
||||
- **k** (int) - 指定计算最大元素的数量,需要是常量。
|
||||
|
||||
输出:
|
||||
2个Tensor组成的tuple, `values` 和 `indices` 。
|
||||
|
||||
- **values** (Tensor) - 最后一个维度的每个切片中的 `k` 最大元素。
|
||||
- **indices** (Tensor) - `k` 最大元素的对应索引。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `sorted` 不是bool。
|
||||
- **TypeError** - 如果 `input_x` 不是Tensor。
|
||||
- **TypeError** - 如果 `k` 不是int。
|
||||
- **TypeError** - 如果 `input_x` 的数据类型不是以下之一:float16、float32或int32。
|
||||
更多参考详见 :func:`mindspore.ops.top_k`。
|
|
@ -5,26 +5,4 @@
|
|||
|
||||
沿分段计算输入Tensor元素的和。
|
||||
|
||||
计算输出Tensor :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]` ,其中 :math:`j,...` 是代表元素索引的Tuple。 `segment_ids` 确定输入Tensor元素的分段。 `segment_ids` 不需要排序,也不需要覆盖 `num_segments` 范围内的所有值。
|
||||
|
||||
UnsortedSegmentSum的计算过程如下图所示:
|
||||
|
||||
.. image:: UnsortedSegmentSum.png
|
||||
|
||||
.. note::
|
||||
- 如果 `segment_ids` 中不存在segment_id `i` ,则对输出 `output[i]` 填充0。
|
||||
- 在Ascend平台上,如果segment_id的值小于0或大于输入Tensor的shape的长度,将触发执行错误。
|
||||
|
||||
如果 `segment_ids` 元素为负数,将忽略该值。 `num_segments` 必须等于不同segment_id的数量。
|
||||
|
||||
输入:
|
||||
- **input_x** (Tensor) - shape: :math:`(x_1, x_2, ..., x_R)` 。
|
||||
- **segment_ids** (Tensor) - shape为 :math:`(x_1)` 的1维张量,值必须是非负数。数据类型支持int32。
|
||||
- **num_segments** (int) - 分段数量 :math:`z` 。
|
||||
|
||||
输出:
|
||||
Tensor,shape: :math:`(z, x_{N+1}, ..., x_R)` 。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `num_segments` 不是int类型。
|
||||
- **ValueError** - `segment_ids` 的维度小于1。
|
||||
更多参考详见 :func:`mindspore.ops.unsorted_segment_sum`。
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
mindspore.ops.top_k
|
||||
===================
|
||||
|
||||
.. py:function:: mindspore.ops.top_k(input_x, k, sorted=True)
|
||||
|
||||
沿最后一个维度查找 `k` 个最大元素和对应的索引。
|
||||
|
||||
.. warning::
|
||||
- 如果 `sorted` 设置为'False',它将使用aicpu运算符,性能可能会降低。
|
||||
|
||||
如果 `input_x` 是一维Tensor,则查找Tensor中 `k` 个最大元素,并将其值和索引输出为Tensor。因此, `values[k]` 是 `input_x` 中 `k` 个最大元素,其索引是 `indices[k]` 。
|
||||
|
||||
对于多维矩阵,计算每行中最大的 `k` 个元素(沿最后一个维度的相应向量),因此:
|
||||
|
||||
.. math::
|
||||
values.shape = indices.shape = input.shape[:-1] + [k].
|
||||
|
||||
如果两个比较的元素相同,则优先返回索引值较小的元素。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **input_x** (Tensor) - 需计算的输入,数据类型必须为float16、float32或int32。
|
||||
- **k** (int) - 指定计算最大元素的数量,需要是常量。
|
||||
- **sorted** (bool, optional) - 如果为True,则获取的元素将按值降序排序。默认值:True。
|
||||
|
||||
**返回:**
|
||||
|
||||
2个Tensor组成的tuple, `values` 和 `indices` 。
|
||||
|
||||
- **values** (Tensor) - 最后一个维度的每个切片中的 `k` 最大元素。
|
||||
- **indices** (Tensor) - `k` 最大元素的对应索引。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - 如果 `sorted` 不是bool。
|
||||
- **TypeError** - 如果 `input_x` 不是Tensor。
|
||||
- **TypeError** - 如果 `k` 不是int。
|
||||
- **TypeError** - 如果 `input_x` 的数据类型不是以下之一:float16、float32或int32。
|
|
@ -416,6 +416,7 @@ Array Operation
|
|||
mindspore.ops.tensor_scatter_sub
|
||||
mindspore.ops.tensor_scatter_elements
|
||||
mindspore.ops.tile
|
||||
mindspore.ops.top_k
|
||||
mindspore.ops.transpose
|
||||
mindspore.ops.unique
|
||||
mindspore.ops.unique_consecutive
|
||||
|
|
|
@ -270,6 +270,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"erf", std::string("erf")}, // P.Erf()
|
||||
{"erfc", std::string("erfc")}, // P.Erfc()
|
||||
{"arg_min_with_value", std::string("arg_min_with_value")}, // P.ArgMinWithValue
|
||||
{"top_k", std::string("top_k")}, // P.TopK()
|
||||
}},
|
||||
{kObjectTypeRowTensorType,
|
||||
{
|
||||
|
|
|
@ -2159,6 +2159,7 @@ check_is_int = constexpr(validator.check_is_int)
|
|||
check_type_name = constexpr(validator.check_type_name)
|
||||
check_value_type = constexpr(validator.check_value_type)
|
||||
check_int = constexpr(validator.check_int)
|
||||
check_bool = constexpr(validator.check_bool)
|
||||
|
||||
|
||||
def tensor_bool(x):
|
||||
|
@ -2283,6 +2284,15 @@ def ceil(x):
|
|||
return F.ceil(x)
|
||||
|
||||
|
||||
def top_k(input_x, k, sorted=True):
|
||||
"""
|
||||
Finds values and indices of the `k` largest entries along the last dimension.
|
||||
"""
|
||||
check_is_int(k, 'k')
|
||||
check_bool(sorted, 'sorted')
|
||||
return F.top_k(input_x, k, sorted)
|
||||
|
||||
|
||||
#############
|
||||
# Iteration #
|
||||
#############
|
||||
|
|
|
@ -4580,6 +4580,62 @@ class Tensor(Tensor_):
|
|||
return tensor_operator_registry.get("erfc")()(self)
|
||||
|
||||
|
||||
def top_k(self, k, sorted=True):
|
||||
r"""
|
||||
Finds values and indices of the `k` largest entries along the last dimension.
|
||||
|
||||
.. warning::
|
||||
- If sorted is set to 'False', it will use the aicpu operator, the performance may be reduced.
|
||||
|
||||
`input_x` refer to self tensor.
|
||||
|
||||
If the `input_x` is a one-dimensional Tensor, finds the `k` largest entries in the Tensor,
|
||||
and outputs its value and index as a Tensor. Therefore, values[`k`] is the `k` largest item in `input_x`,
|
||||
and its index is indices [`k`].
|
||||
|
||||
For a multi-dimensional matrix,
|
||||
calculates the first `k` entries in each row (corresponding vector along the last dimension), therefore:
|
||||
|
||||
.. math::
|
||||
|
||||
values.shape = indices.shape = input\_x.shape[:-1] + [k].
|
||||
|
||||
If the two compared elements are the same, the one with the smaller index value is returned first.
|
||||
|
||||
Args:
|
||||
k (int): The number of top elements to be computed along the last dimension, constant input is needed.
|
||||
sorted (bool, optional): If true, the obtained elements will be sorted by the values in descending order.
|
||||
Default: True.
|
||||
|
||||
Returns:
|
||||
Tuple of 2 tensors, the values and the indices.
|
||||
|
||||
- values (Tensor): The `k` largest elements in each slice of the last dimension.
|
||||
- indices (Tensor): The indices of values within the last dimension of input.
|
||||
|
||||
Raises:
|
||||
TypeError: If `k` is not an int.
|
||||
TypeError: If `sorted` is not a bool.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore import Tensor
|
||||
>>> input_x = Tensor([1, 2, 3, 4, 5], ms.float16)
|
||||
>>> k = 3
|
||||
>>> values, indices = input_x.top_k(k, sorted=True)
|
||||
>>> print((values, indices))
|
||||
(Tensor(shape=[3], dtype=Float16, value= [ 5.0000e+00, 4.0000e+00, 3.0000e+00]), Tensor(shape=[3],
|
||||
dtype=Int32, value= [4, 3, 2]))
|
||||
"""
|
||||
self._init_check()
|
||||
validator.check_is_int(k, 'k')
|
||||
validator.check_bool(sorted, 'sorted')
|
||||
return tensor_operator_registry.get("top_k")(sorted)(self, k)
|
||||
|
||||
|
||||
class RowTensor(RowTensor_):
|
||||
"""
|
||||
A sparse representation of a set of tensor slices at given indices.
|
||||
|
|
|
@ -109,6 +109,7 @@ from .array_func import (
|
|||
max,
|
||||
min,
|
||||
population_count,
|
||||
top_k,
|
||||
)
|
||||
from .parameter_func import (
|
||||
assign,
|
||||
|
|
|
@ -4032,6 +4032,62 @@ def unsorted_segment_sum(input_x, segment_ids, num_segments):
|
|||
return unsorted_segment_sum_(input_x, segment_ids, num_segments)
|
||||
|
||||
|
||||
def top_k(input_x, k, sorted=True):
|
||||
r"""
|
||||
Finds values and indices of the `k` largest entries along the last dimension.
|
||||
|
||||
.. warning::
|
||||
- If sorted is set to 'False', it will use the aicpu operator, the performance may be reduced.
|
||||
|
||||
If the `input_x` is a one-dimensional Tensor, finds the `k` largest entries in the Tensor,
|
||||
and outputs its value and index as a Tensor. Therefore, values[`k`] is the `k` largest item in `input_x`,
|
||||
and its index is indices [`k`].
|
||||
|
||||
For a multi-dimensional matrix,
|
||||
calculates the first `k` entries in each row (corresponding vector along the last dimension), therefore:
|
||||
|
||||
.. math::
|
||||
|
||||
values.shape = indices.shape = input.shape[:-1] + [k].
|
||||
|
||||
If the two compared elements are the same, the one with the smaller index value is returned first.
|
||||
|
||||
Args:
|
||||
input_x (Tensor): Input to be computed, data type must be float16, float32 or int32.
|
||||
k (int): The number of top elements to be computed along the last dimension, constant input is needed.
|
||||
sorted (bool, optional): If true, the obtained elements will be sorted by the values in descending order.
|
||||
Default: True.
|
||||
|
||||
Returns:
|
||||
Tuple of 2 tensors, the values and the indices.
|
||||
|
||||
- values (Tensor): The `k` largest elements in each slice of the last dimension.
|
||||
- indices (Tensor): The indices of values within the last dimension of input.
|
||||
|
||||
Raises:
|
||||
TypeError: If `sorted` is not a bool.
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
TypeError: If `k` is not an int.
|
||||
TypeError: If dtype of `input_x` is not one of the following: float16, float32 or int32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.ops as ops
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore import Tensor
|
||||
>>> input_x = Tensor([1, 2, 3, 4, 5], ms.float16)
|
||||
>>> k = 3
|
||||
>>> values, indices = ops.top_k(input_x, k, sorted=True)
|
||||
>>> print((values, indices))
|
||||
(Tensor(shape=[3], dtype=Float16, value= [ 5.0000e+00, 4.0000e+00, 3.0000e+00]), Tensor(shape=[3],
|
||||
dtype=Int32, value= [4, 3, 2]))
|
||||
"""
|
||||
top_k_ = _get_cache_prim(P.TopK)(sorted)
|
||||
return top_k_(input_x, k)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'unique',
|
||||
'unique_with_pad',
|
||||
|
@ -4116,5 +4172,6 @@ __all__ = [
|
|||
'min',
|
||||
'unsorted_segment_sum',
|
||||
'population_count',
|
||||
'top_k',
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -889,7 +889,7 @@ tensor_operator_registry.register('renorm', renorm)
|
|||
tensor_operator_registry.register('adaptive_max_pool2d', AdaptiveMaxPool2D)
|
||||
tensor_operator_registry.register('coalesce', coalesce)
|
||||
tensor_operator_registry.register('arg_min_with_value', min)
|
||||
tensor_operator_registry.register('unsorted_segment_sum', P.UnsortedSegmentSum)
|
||||
tensor_operator_registry.register('coo_add', sparse_add)
|
||||
tensor_operator_registry.register('top_k', P.TopK)
|
||||
__all__ = [name for name in dir() if name[0] != "_"]
|
||||
__all__.remove('Primitive')
|
||||
|
|
|
@ -47,7 +47,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
|
|||
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
|
||||
TensorScatterUpdate, TensorScatterMax, TensorScatterMin, TensorScatterAdd, TensorScatterSub,
|
||||
TensorScatterMul, TensorScatterDiv, ExtractVolumePatches, LowerBound,
|
||||
UpperBound, Cummax, Mvlgamma, PopulationCount)
|
||||
UpperBound, Cummax, Mvlgamma, PopulationCount, TopK)
|
||||
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter,
|
||||
Broadcast,
|
||||
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
|
||||
|
@ -92,7 +92,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
|
|||
SmoothL1Loss, SoftMarginLoss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
|
||||
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
||||
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
||||
TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
|
||||
BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
|
||||
ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2,
|
||||
FusedSparseFtrl, FusedSparseProximalAdagrad, SparseApplyRMSProp, SparseApplyAdadelta,
|
||||
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAdagradDA,
|
||||
|
|
|
@ -7973,3 +7973,34 @@ class PopulationCount(Primitive):
|
|||
def __init__(self):
|
||||
"""Initialize PopulationCount"""
|
||||
self.init_prim_io_names(inputs=['input'], outputs=['output'])
|
||||
|
||||
|
||||
class TopK(Primitive):
|
||||
"""
|
||||
Finds values and indices of the `k` largest entries along the last dimension.
|
||||
|
||||
Refer to :func:`mindspore.ops.top_k` for more detail.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops.operations.nn_ops import TopK
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore import Tensor
|
||||
>>> top_k = TopK(sorted=True)
|
||||
>>> input_x = Tensor([1, 2, 3, 4, 5], ms.float16)
|
||||
>>> k = 3
|
||||
>>> values, indices = top_k(input_x, k)
|
||||
>>> print((values, indices))
|
||||
(Tensor(shape=[3], dtype=Float16, value= [ 5.0000e+00, 4.0000e+00, 3.0000e+00]), Tensor(shape=[3],
|
||||
dtype=Int32, value= [4, 3, 2]))
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, sorted=True):
|
||||
"""Initialize TopK."""
|
||||
self.sorted = validator.check_value_type("sorted", sorted, [bool], self.name)
|
||||
self.add_prim_attr("sorted", self.sorted)
|
||||
self.init_prim_io_names(inputs=['input', 'k'],
|
||||
outputs=['values', 'indices'])
|
||||
|
|
|
@ -2673,68 +2673,6 @@ class BiasAdd(Primitive):
|
|||
self.add_prim_attr('data_format', self.format)
|
||||
|
||||
|
||||
class TopK(Primitive):
|
||||
"""
|
||||
Finds values and indices of the `k` largest entries along the last dimension.
|
||||
|
||||
.. warning::
|
||||
- If sorted is set to 'False', it will use the aicpu operator, the performance may be reduced.
|
||||
|
||||
If the `input_x` is a one-dimensional Tensor, finds the `k` largest entries in the Tensor,
|
||||
and outputs its value and index as a Tensor. Therefore, values[`k`] is the `k` largest item in `input_x`,
|
||||
and its index is indices [`k`].
|
||||
|
||||
For a multi-dimensional matrix,
|
||||
calculates the first `k` entries in each row (corresponding vector along the last dimension), therefore:
|
||||
|
||||
.. math::
|
||||
|
||||
values.shape = indices.shape = input.shape[:-1] + [k].
|
||||
|
||||
If the two compared elements are the same, the one with the smaller index value is returned first.
|
||||
|
||||
Args:
|
||||
sorted (bool): If true, the obtained elements will
|
||||
be sorted by the values in descending order. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - Input to be computed, data type must be float16, float32 or int32.
|
||||
- **k** (int) - The number of top elements to be computed along the last dimension, constant input is needed.
|
||||
|
||||
Outputs:
|
||||
Tuple of 2 tensors, the values and the indices.
|
||||
|
||||
- **values** (Tensor) - The `k` largest elements in each slice of the last dimension.
|
||||
- **indices** (Tensor) - The indices of values within the last dimension of input.
|
||||
|
||||
Raises:
|
||||
TypeError: If `sorted` is not a bool.
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
TypeError: If `k` is not an int.
|
||||
TypeError: If dtype of `input_x` is not one of the following: float16, float32 or int32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> topk = ops.TopK(sorted=True)
|
||||
>>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16)
|
||||
>>> k = 3
|
||||
>>> values, indices = topk(input_x, k)
|
||||
>>> print((values, indices))
|
||||
(Tensor(shape=[3], dtype=Float16, value= [ 5.0000e+00, 4.0000e+00, 3.0000e+00]), Tensor(shape=[3],
|
||||
dtype=Int32, value= [4, 3, 2]))
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, sorted=True):
|
||||
"""Initialize TopK."""
|
||||
self.sorted = validator.check_value_type("sorted", sorted, [bool], self.name)
|
||||
self.add_prim_attr("sorted", self.sorted)
|
||||
self.init_prim_io_names(inputs=['input', 'k'],
|
||||
outputs=['values', 'indices'])
|
||||
|
||||
|
||||
class NLLLoss(PrimitiveWithInfer):
|
||||
r"""
|
||||
Gets the negative log likelihood loss between logits and labels.
|
||||
|
|
|
@ -86,7 +86,7 @@ def unsorted_segment_arith_expected(func, x, segment_ids, num_segments):
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('func', ['min', 'max'])
|
||||
@pytest.mark.parametrize('func', ['min', 'max', 'sum'])
|
||||
def test_unsorted_segment_op(func):
|
||||
"""
|
||||
Feature: test_unsorted_segment_op* operators.
|
||||
|
@ -103,6 +103,8 @@ def test_unsorted_segment_op(func):
|
|||
graph_output = P.UnsortedSegmentMin()(x, segment_ids, num_segments)
|
||||
if func == 'max':
|
||||
graph_output = P.UnsortedSegmentMax()(x, segment_ids, num_segments)
|
||||
if func == 'sum':
|
||||
graph_output = P.UnsortedSegmentSum()(x, segment_ids, num_segments)
|
||||
|
||||
expected = unsorted_segment_arith_expected(func, x, segment_ids, num_segments)
|
||||
np.testing.assert_array_almost_equal(graph_output.asnumpy(), expected)
|
||||
|
@ -116,6 +118,8 @@ class TestUnsortedSegmentArithmeticNet(nn.Cell):
|
|||
self.func = P.UnsortedSegmentMin()
|
||||
if func == 'max':
|
||||
self.func = P.UnsortedSegmentMax()
|
||||
if func == 'sum':
|
||||
self.func = P.UnsortedSegmentSum()
|
||||
self.num_segments = num_segments
|
||||
|
||||
def construct(self, x, segment_ids):
|
||||
|
@ -126,7 +130,7 @@ class TestUnsortedSegmentArithmeticNet(nn.Cell):
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('func', ['min', 'max'])
|
||||
@pytest.mark.parametrize('func', ['min', 'max', 'sum'])
|
||||
def test_unsorted_segment_op_dynamic_shape(func):
|
||||
"""
|
||||
Feature: test_unsorted_segment_op_dynamic_shape.
|
||||
|
|
|
@ -21,6 +21,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
@ -69,6 +70,7 @@ def test_slice2():
|
|||
output = slice_op(x)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
def test_slice_float64():
|
||||
data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]],
|
||||
[[3, 3, 3], [4, 4, 4]],
|
||||
|
@ -78,6 +80,7 @@ def test_slice_float64():
|
|||
expect = [[[3.0, 3.0, 3.0]]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
class Slice3(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Slice3, self).__init__()
|
||||
|
@ -179,6 +182,7 @@ class StridedSlice(nn.Cell):
|
|||
def construct(self, x):
|
||||
return self.stride_slice(x, self.begin, self.end, self.stride)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -193,6 +197,25 @@ def test_strided_slice_bool_type():
|
|||
expected_output = np.array([False, True, False])
|
||||
assert (output.asnumpy() == expected_output).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_slice_functional():
|
||||
"""
|
||||
Feature: test_slice_functional
|
||||
Description: test slice functional API
|
||||
Expectation: the output is as expected
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
x = Tensor(
|
||||
np.array([[[1, -1, 1], [2, -2, 2]], [[3, -3, 3], [4, -4, 4]], [[5, -5, 5], [6, -6, 6]]]), mstype.float32)
|
||||
expect = [[[2., -2., 2.]],
|
||||
[[4., -4., 4.]]]
|
||||
output = F.slice(x, begin=(0, 1, 0), size=(2, 1, 3))
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_slice()
|
||||
test_slice2()
|
||||
|
@ -201,3 +224,4 @@ if __name__ == '__main__':
|
|||
test_slice5()
|
||||
test_slice6()
|
||||
test_strided_slice_bool_type()
|
||||
test_slice_functional()
|
||||
|
|
|
@ -19,6 +19,7 @@ import pytest
|
|||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -80,3 +81,71 @@ def test_topk():
|
|||
k = 40960
|
||||
ms_output = P.TopK(False)(Tensor(x_np), k)
|
||||
assert np.allclose(ms_output[0].asnumpy(), x_np)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_top_k_functional():
|
||||
"""
|
||||
Feature: test_top_k_functional
|
||||
Description: test top_k functional API
|
||||
Expectation: the output is as expected
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
x_np = np.random.rand(3, 4).astype(np.float32)
|
||||
k = 4
|
||||
ms_output = F.top_k(Tensor(x_np), k, True)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(3, 4).astype(np.float32)
|
||||
k = 4
|
||||
ms_output = F.top_k(Tensor(x_np), k, False)
|
||||
assert np.allclose(ms_output[0].asnumpy(), x_np)
|
||||
|
||||
x_np = np.random.rand(2, 3, 4).astype(np.float32)
|
||||
k = 2
|
||||
ms_output = F.top_k(Tensor(x_np), k, True)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(512, 1024).astype(np.float32)
|
||||
k = 512
|
||||
ms_output = F.top_k(Tensor(x_np), k, True)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_top_k_tensor():
|
||||
"""
|
||||
Feature: test_top_k_tensor
|
||||
Description: test top_k tensor API
|
||||
Expectation: the output is as expected
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
x_np = np.random.rand(3, 4).astype(np.float32)
|
||||
k = 4
|
||||
ms_output = Tensor(x_np).top_k(k, True)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(3, 4).astype(np.float32)
|
||||
k = 4
|
||||
ms_output = Tensor(x_np).top_k(k, False)
|
||||
assert np.allclose(ms_output[0].asnumpy(), x_np)
|
||||
|
||||
x_np = np.random.rand(2, 3, 4).astype(np.float32)
|
||||
k = 2
|
||||
ms_output = Tensor(x_np).top_k(k, True)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(512, 1024).astype(np.float32)
|
||||
k = 512
|
||||
ms_output = Tensor(x_np).top_k(k, True)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
@ -31,7 +32,6 @@ UnsortedSegmentArith_func_map = {
|
|||
"prod": ops.UnsortedSegmentProd,
|
||||
}
|
||||
|
||||
|
||||
arith_np_func_map = {
|
||||
"prod": lambda a, b: a * b,
|
||||
"sum": lambda a, b: a + b,
|
||||
|
@ -90,7 +90,6 @@ def unsorted_segment_arith_expected(func, x, segment_ids, num_segments):
|
|||
trans_inp = np_inp.reshape(trans_inp_shape)
|
||||
trans_ids = np_ids.reshape(ids_size)
|
||||
|
||||
|
||||
for i in range(ids_size):
|
||||
out_index = trans_ids[i]
|
||||
if out_index < 0:
|
||||
|
@ -105,7 +104,7 @@ def unsorted_segment_arith_expected(func, x, segment_ids, num_segments):
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('func', ['min', 'max'])
|
||||
@pytest.mark.parametrize('func', ['min', 'max', 'sum'])
|
||||
@pytest.mark.parametrize('data_type', [mstype.float32, mstype.int32])
|
||||
@pytest.mark.parametrize('index_type', [mstype.int32])
|
||||
def test_unsorted_segment_arithmetic_one_d(func, data_type, index_type):
|
||||
|
@ -188,6 +187,8 @@ def test_tensor_check(func):
|
|||
output_ms = x.unsorted_segment_min(segment_ids, num_segments)
|
||||
if func == 'max':
|
||||
output_ms = x.unsorted_segment_max(segment_ids, num_segments)
|
||||
if func == 'sum':
|
||||
output_ms = x.unsorted_segment_sum(segment_ids, num_segments)
|
||||
|
||||
expected = unsorted_segment_arith_expected(func, x, segment_ids, num_segments)
|
||||
np.testing.assert_array_almost_equal(output_ms.asnumpy(), expected)
|
||||
|
@ -196,7 +197,7 @@ def test_tensor_check(func):
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('func', ['min', 'max'])
|
||||
@pytest.mark.parametrize('func', ['min', 'max', 'sum'])
|
||||
def test_functional_check(func):
|
||||
"""
|
||||
Feature: test_functional_check.
|
||||
|
@ -212,6 +213,8 @@ def test_functional_check(func):
|
|||
output_ms = F.unsorted_segment_min(x, segment_ids, num_segments)
|
||||
if func == 'max':
|
||||
output_ms = F.unsorted_segment_max(x, segment_ids, num_segments)
|
||||
if func == 'sum':
|
||||
output_ms = F.unsorted_segment_sum(x, segment_ids, num_segments)
|
||||
|
||||
expected = unsorted_segment_arith_expected(func, x, segment_ids, num_segments)
|
||||
np.testing.assert_array_almost_equal(output_ms.asnumpy(), expected)
|
||||
|
|
|
@ -1,105 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
class UnsortedSegmentSumNet(nn.Cell):
|
||||
def __init__(self, num_segments):
|
||||
super(UnsortedSegmentSumNet, self).__init__()
|
||||
self.unsorted_segment_sum = P.UnsortedSegmentSum()
|
||||
self.num_segments = num_segments
|
||||
|
||||
def construct(self, data, ids):
|
||||
return self.unsorted_segment_sum(data, ids, self.num_segments)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_1D():
|
||||
input_x = Tensor([1, 2, 3, 4], mstype.float32)
|
||||
segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
|
||||
num_segments = 4
|
||||
|
||||
net = UnsortedSegmentSumNet(num_segments)
|
||||
output = net(input_x, segment_ids)
|
||||
expect = [3, 3, 4, 0]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_2D():
|
||||
input_x = Tensor([[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12]], mstype.float32)
|
||||
segment_ids = Tensor([2, 1, 1], mstype.int32)
|
||||
num_segments = 4
|
||||
|
||||
net = UnsortedSegmentSumNet(num_segments)
|
||||
output = net(input_x, segment_ids)
|
||||
expect = [[0, 0, 0, 0],
|
||||
[14, 16, 18, 20],
|
||||
[1, 2, 3, 4],
|
||||
[0, 0, 0, 0]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_3D():
|
||||
input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3))
|
||||
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
|
||||
num_segments = 5
|
||||
|
||||
net = UnsortedSegmentSumNet(num_segments)
|
||||
output = net(input_x, segment_ids)
|
||||
expect = [[[0., 0., 0.],
|
||||
[0., 0., 0.],
|
||||
[0., 0., 0.],
|
||||
[0., 0., 0.],
|
||||
[0., 0., 0.]],
|
||||
[[45., 47., 49.],
|
||||
[51., 53., 55.],
|
||||
[57., 59., 61.],
|
||||
[63., 65., 67.],
|
||||
[69., 71., 73.]],
|
||||
[[0., 1., 2.],
|
||||
[3., 4., 5.],
|
||||
[6., 7., 8.],
|
||||
[9., 10., 11.],
|
||||
[12., 13., 14.]],
|
||||
[[0., 0., 0.],
|
||||
[0., 0., 0.],
|
||||
[0., 0., 0.],
|
||||
[0., 0., 0.],
|
||||
[0., 0., 0.]],
|
||||
[[0., 0., 0.],
|
||||
[0., 0., 0.],
|
||||
[0., 0., 0.],
|
||||
[0., 0., 0.],
|
||||
[0., 0., 0.]]]
|
||||
assert (output.asnumpy() == expect).all()
|
Loading…
Reference in New Issue