forked from mindspore-Ecosystem/mindspore
!47641 添加 ops.index_select/Tensor.index_select, ops.lt/Tensor.lt
Merge pull request !47641 from DavidFFFan/api_ops
This commit is contained in:
commit
cb5f0c0617
|
@ -343,6 +343,7 @@ Reduction函数
|
|||
mindspore.ops.isreal
|
||||
mindspore.ops.le
|
||||
mindspore.ops.less
|
||||
mindspore.ops.lt
|
||||
mindspore.ops.maximum
|
||||
mindspore.ops.minimum
|
||||
mindspore.ops.ne
|
||||
|
@ -452,6 +453,7 @@ Array操作
|
|||
mindspore.ops.hsplit
|
||||
mindspore.ops.index_add
|
||||
mindspore.ops.index_fill
|
||||
mindspore.ops.index_select
|
||||
mindspore.ops.inplace_add
|
||||
mindspore.ops.inplace_sub
|
||||
mindspore.ops.inplace_update
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
mindspore.Tensor.index_fill
|
||||
===========================
|
||||
|
||||
.. py:method:: mindspore.Tensor.index_fill(dim, index, value)
|
||||
.. py:method:: mindspore.Tensor.index_fill(axis, index, value)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.index_fill`。
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.index_select
|
||||
=============================
|
||||
|
||||
.. py:method:: mindspore.Tensor.index_select(axis, index)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.index_select`。
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.lt
|
||||
===================
|
||||
|
||||
.. py:method:: mindspore.Tensor.lt(other)
|
||||
|
||||
:func:`mindspore.Tensor.less` 的别名。
|
|
@ -138,6 +138,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.igammac
|
||||
mindspore.Tensor.index_add
|
||||
mindspore.Tensor.index_fill
|
||||
mindspore.Tensor.index_select
|
||||
mindspore.Tensor.init_data
|
||||
mindspore.Tensor.inner
|
||||
mindspore.Tensor.inplace_update
|
||||
|
@ -180,6 +181,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.logit
|
||||
mindspore.Tensor.logsumexp
|
||||
mindspore.Tensor.long
|
||||
mindspore.Tensor.lt
|
||||
mindspore.Tensor.masked_fill
|
||||
mindspore.Tensor.masked_select
|
||||
mindspore.Tensor.matrix_determinant
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
mindspore.ops.index_fill
|
||||
========================
|
||||
|
||||
.. py:function:: mindspore.ops.index_fill(x, dim, index, value)
|
||||
.. py:function:: mindspore.ops.index_fill(x, axis, index, value)
|
||||
|
||||
按 `index` 中给定的顺序选择索引,将输入 `value` 值填充到输入Tensor `x` 的所有 `dim` 维元素。
|
||||
按 `index` 中给定的顺序选择索引,将输入 `value` 值填充到输入Tensor `x` 的所有 `axis` 维元素。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 输入Tensor,支持的数据类型是数值型和布尔型。
|
||||
- **dim** (Union[int, Tensor]) - 填充输入Tensor的维度,要求是一个int或者数据类型为int32或int64的零维Tensor。
|
||||
- **axis** (Union[int, Tensor]) - 填充输入Tensor的维度,要求是一个int或者数据类型为int32或int64的零维Tensor。
|
||||
- **index** (Tensor) - 填充输入Tensor的索引,数据类型为int32。
|
||||
- **value** (Union[bool, int, float, Tensor]) - 填充输入Tensor的值。如果 `value` 是Tensor,那么 `value` 要求是数据类型与 `x` 相同的零维Tensor。否则,该值会自动转化为一个数据类型与 `x` 相同的零维Tensor。
|
||||
|
||||
|
@ -16,14 +16,14 @@ mindspore.ops.index_fill
|
|||
|
||||
异常:
|
||||
- **TypeError** - `x` 的类型不是Tensor。
|
||||
- **TypeError** - `dim` 的类型不是int或者Tensor。
|
||||
- **TypeError** - 当 `dim` 是Tensor时, `dim` 的数据类型不是int32或者int64。
|
||||
- **TypeError** - `axis` 的类型不是int或者Tensor。
|
||||
- **TypeError** - 当 `axis` 是Tensor时, `axis` 的数据类型不是int32或者int64。
|
||||
- **TypeError** - `index` 的类型不是Tensor。
|
||||
- **TypeError** - `index` 的数据类型不是int32。
|
||||
- **TypeError** - `value` 的类型不是bool、int、float或者Tensor。
|
||||
- **TypeError** - 当 `value` 是Tensor时, `value` 的数据类型和 `x` 的数据类型不相同。
|
||||
- **ValueError** - 当 `dim` 是Tensor时, `dim` 的维度不等于0。
|
||||
- **ValueError** - 当 `axis` 是Tensor时, `axis` 的维度不等于0。
|
||||
- **ValueError** - `index` 的维度大于1。
|
||||
- **ValueError** - 当 `value` 是Tensor时, `value` 的维度不等于0。
|
||||
- **RuntimeError** - `dim` 值超出范围[-x.ndim, x.ndim - 1]。
|
||||
- **RuntimeError** - `index` 存在值超出范围[-x.shape[dim], x.shape[dim]-1]。
|
||||
- **RuntimeError** - `axis` 值超出范围[-x.ndim, x.ndim - 1]。
|
||||
- **RuntimeError** - `index` 存在值超出范围[-x.shape[axis], x.shape[axis]-1]。
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
mindspore.ops.index_select
|
||||
==========================
|
||||
|
||||
.. py:function:: mindspore.ops.index_select(x, axis, index)
|
||||
|
||||
返回一个新的Tensor,该Tensor沿维度 `axis` 按 `index` 中给定的顺序对 `x` 进行选择。
|
||||
|
||||
返回的Tensor和输入Tensor( `x` )的维度数量相同,其第 `axis` 维度的大小和 `index` 的长度相同;其它维度和 `x` 相同。
|
||||
|
||||
.. note::
|
||||
index的值必须在 `[0, x.shape[axis])` 范围内,超出该范围结果未定义。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 输入Tensor。
|
||||
- **axis** (int) - `index` 的维度。
|
||||
- **index** (Tensor) - 包含索引的一维Tensor。数据类型为int32或int64。
|
||||
|
||||
返回:
|
||||
Tensor,数据类型与输入 `x` 相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 或 `index` 的类型不是Tensor。
|
||||
- **TypeError** - `axis` 的类型不是int。
|
||||
- **ValueError** - `axis` 值超出范围[-x.ndim, x.ndim - 1]。
|
||||
- **ValueError** - `index` 不是一维Tensor。
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.ops.lt
|
||||
================
|
||||
|
||||
.. py:function:: mindspore.ops.lt(x, other)
|
||||
|
||||
:func:`mindspore.ops.less` 的别名。
|
|
@ -144,6 +144,7 @@
|
|||
mindspore.Tensor.igammac
|
||||
mindspore.Tensor.index_add
|
||||
mindspore.Tensor.index_fill
|
||||
mindspore.Tensor.index_select
|
||||
mindspore.Tensor.init_data
|
||||
mindspore.Tensor.inner
|
||||
mindspore.Tensor.inplace_update
|
||||
|
@ -186,6 +187,7 @@
|
|||
mindspore.Tensor.logit
|
||||
mindspore.Tensor.logsumexp
|
||||
mindspore.Tensor.long
|
||||
mindspore.Tensor.lt
|
||||
mindspore.Tensor.masked_fill
|
||||
mindspore.Tensor.masked_select
|
||||
mindspore.Tensor.matrix_determinant
|
||||
|
|
|
@ -343,6 +343,7 @@ Comparison Functions
|
|||
mindspore.ops.is_complex
|
||||
mindspore.ops.le
|
||||
mindspore.ops.less
|
||||
mindspore.ops.lt
|
||||
mindspore.ops.maximum
|
||||
mindspore.ops.minimum
|
||||
mindspore.ops.ne
|
||||
|
@ -452,6 +453,7 @@ Array Operation
|
|||
mindspore.ops.hsplit
|
||||
mindspore.ops.index_add
|
||||
mindspore.ops.index_fill
|
||||
mindspore.ops.index_select
|
||||
mindspore.ops.inplace_add
|
||||
mindspore.ops.inplace_sub
|
||||
mindspore.ops.inplace_update
|
||||
|
|
|
@ -244,6 +244,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"cummin", std::string("cummin")}, // cummin()
|
||||
{"cummax", std::string("cummax")}, // cummax()
|
||||
{"index_fill", std::string("index_fill")}, // index_fill()
|
||||
{"index_select", std::string("index_select")}, // index_select()
|
||||
{"repeat_interleave", std::string("repeat_interleave")}, // repeat_interleave()
|
||||
{"copy", std::string("copy")}, // copy()
|
||||
{"copysign", std::string("copysign")}, // copysign()
|
||||
|
@ -429,6 +430,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"isnan", std::string("isnan")}, // isnan()
|
||||
{"le", std::string("le")}, // le()
|
||||
{"less", std::string("less")}, // less()
|
||||
{"lt", std::string("less")}, // lt()
|
||||
{"logical_and", std::string("logical_and")}, // logical_and()
|
||||
{"logical_not", std::string("logical_not")}, // logical_not()
|
||||
{"logical_or", std::string("logical_or")}, // logical_or()
|
||||
|
|
|
@ -992,12 +992,19 @@ def cummax(x, axis):
|
|||
return F.cummax(x, axis)
|
||||
|
||||
|
||||
def index_fill(x, dim, index, value):
|
||||
def index_fill(x, axis, index, value):
|
||||
"""
|
||||
Fills the elements under the dim dimension of the input Tensor with the input value
|
||||
Fills the elements under the axis dimension of the input Tensor with the input value
|
||||
by selecting the indices in the order given in index.
|
||||
"""
|
||||
return F.index_fill(x, dim, index, value)
|
||||
return F.index_fill(x, axis, index, value)
|
||||
|
||||
|
||||
def index_select(x, axis, index):
|
||||
"""
|
||||
Returns a new tensor which indexes the `x` tensor along dimension `axis` using the entries in `index` .
|
||||
"""
|
||||
return F.index_select(x, axis, index)
|
||||
|
||||
|
||||
def copy(x):
|
||||
|
|
|
@ -2086,11 +2086,18 @@ class Tensor(Tensor_):
|
|||
"""
|
||||
return tensor_operator_registry.get('cummax')(self, axis)
|
||||
|
||||
def index_fill(self, dim, index, value):
|
||||
def index_fill(self, axis, index, value):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.index_fill`.
|
||||
"""
|
||||
return tensor_operator_registry.get('index_fill')(self, dim, index, value)
|
||||
return tensor_operator_registry.get('index_fill')(self, axis, index, value)
|
||||
|
||||
def index_select(self, axis, index):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.index_select`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('index_select')(self, axis, index)
|
||||
|
||||
def inplace_update(self, v, indices):
|
||||
"""
|
||||
|
@ -3956,6 +3963,12 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('less')(self, other)
|
||||
|
||||
def lt(self, other):
|
||||
"""
|
||||
Alias for :func:`mindspore.Tensor.less`.
|
||||
"""
|
||||
return self.less(other)
|
||||
|
||||
def logical_and(self, other):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.logical_and`.
|
||||
|
|
|
@ -126,6 +126,7 @@ from .array_func import (
|
|||
hsplit,
|
||||
dsplit,
|
||||
index_fill,
|
||||
index_select,
|
||||
max,
|
||||
argmax,
|
||||
min,
|
||||
|
@ -163,6 +164,7 @@ from .math_func import (
|
|||
negative,
|
||||
tensor_lt,
|
||||
less,
|
||||
lt,
|
||||
tensor_le,
|
||||
le,
|
||||
lerp,
|
||||
|
|
|
@ -4266,14 +4266,14 @@ def unsorted_segment_prod(x, segment_ids, num_segments):
|
|||
return unsorted_segment_prod_(x, segment_ids, num_segments)
|
||||
|
||||
|
||||
def index_fill(x, dim, index, value):
|
||||
def index_fill(x, axis, index, value):
|
||||
"""
|
||||
Fills the elements under the `dim` dimension of the input Tensor `x` with the input `value`
|
||||
Fills the elements under the `axis` dimension of the input Tensor `x` with the input `value`
|
||||
by selecting the indices in the order given in `index`.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input Tensor. The supported data type is Number or Bool.
|
||||
dim (Union[int, Tensor]): Dimension along which to fill the input Tensor. Only supports
|
||||
axis (Union[int, Tensor]): Dimension along which to fill the input Tensor. Only supports
|
||||
an int number or a 0-dimensional Tensor, whose data type is int32 or int64.
|
||||
index (Tensor): Indices of the input Tensor to fill in. The dtype must be int32.
|
||||
value (Union[bool, int, float, Tensor]): Value to fill the returned Tensor. If `value` is
|
||||
|
@ -4285,17 +4285,17 @@ def index_fill(x, dim, index, value):
|
|||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If `dim` is neither int number nor Tensor.
|
||||
TypeError: When `dim` is a Tensor, its dtype is not int32 or int64.
|
||||
TypeError: If `axis` is neither int number nor Tensor.
|
||||
TypeError: When `axis` is a Tensor, its dtype is not int32 or int64.
|
||||
TypeError: If `index` is not a Tensor.
|
||||
TypeError: If dtype of `index` is not int32.
|
||||
TypeError: If `value` is not a bool, int, float, or Tensor.
|
||||
TypeError: When `value` is a Tensor, the dtype of `x` and `value` are not the same.
|
||||
ValueError: If `dim` is a Tensor and its rank is not equal to 0.
|
||||
ValueError: If `axis` is a Tensor and its rank is not equal to 0.
|
||||
ValueError: If the rank of `index` is greater than 1D.
|
||||
ValueError: When `value` is a Tensor and its rank is not equal to 0.
|
||||
RuntimeError: If the value of `dim` is out the range of `[-x.ndim, x.ndim - 1]`.
|
||||
RuntimeError: If the values of `index` are out the range of `[-x.shape[dim], x.shape[dim]-1]`.
|
||||
RuntimeError: If the value of `axis` is out the range of `[-x.ndim, x.ndim - 1]`.
|
||||
RuntimeError: If the values of `index` are out the range of `[-x.shape[axis], x.shape[axis]-1]`.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
@ -4314,11 +4314,66 @@ def index_fill(x, dim, index, value):
|
|||
[-2. 5. -2.]
|
||||
[-2. 8. -2.]]
|
||||
"""
|
||||
if isinstance(dim, int) and not isinstance(dim, bool):
|
||||
dim = cast_(dim, mstype.int32)
|
||||
if isinstance(axis, int) and not isinstance(axis, bool):
|
||||
axis = cast_(axis, mstype.int32)
|
||||
if isinstance(value, (bool, float, int)):
|
||||
value = cast_(value, x.dtype)
|
||||
return index_fill_(x, dim, index, value)
|
||||
return index_fill_(x, axis, index, value)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_check_axis_in_range(axis, ndim):
|
||||
"""Checks axes are with the bounds of ndim"""
|
||||
axis = validator.check_axis_in_range(axis, ndim)
|
||||
return axis
|
||||
|
||||
|
||||
def index_select(x, axis, index):
|
||||
"""
|
||||
Returns a new Tensor which indexes the `x` Tensor along dimension `axis` using the entries in `index` .
|
||||
|
||||
The returned Tensor has the same number of dimensions as the original Tensor ( `x` ). The `axis` th dimension
|
||||
has the same size as the length of `index` ; other dimensions have the same size as in the original Tensor.
|
||||
|
||||
.. note::
|
||||
The value of index must be in the range of `[0, x.shape[axis])`, the result is undefined out of range.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input Tensor.
|
||||
axis (int): Dimension in which we index.
|
||||
index (Tensor): The 1-D Tensor containing the indices to index. The data type can be int32 or int64.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same dtype as input Tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` or `index` is not a Tensor.
|
||||
TypeError: If `axis` is not int number.
|
||||
ValueError: If the value of `axis` is out the range of `[-x.ndim, x.ndim - 1]`.
|
||||
ValueError: If the dimension of `index` is not equal to 1.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>> import numpy as np
|
||||
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32))
|
||||
>>> index = Tensor([0, 2], mindspore.int32)
|
||||
>>> y = ops.index_select(x, 1, index)
|
||||
>>> print(y)
|
||||
[[1. 3.]
|
||||
[4. 6.]
|
||||
[7. 9.]]
|
||||
"""
|
||||
if not (isinstance(x, Tensor) and isinstance(index, Tensor)):
|
||||
raise TypeError(f"For 'index_select', inputs `x` and `index` must be all tensors.")
|
||||
if index.ndim != 1:
|
||||
raise ValueError(f"For 'index_select', the dimension of `index` must be 1, but got {index.ndim}")
|
||||
axis = _check_check_axis_in_range(axis, x.ndim)
|
||||
return gather_(x, index, axis)
|
||||
|
||||
|
||||
def population_count(input_x):
|
||||
|
@ -6021,7 +6076,8 @@ __all__ = [
|
|||
'vsplit',
|
||||
'hsplit',
|
||||
'dsplit',
|
||||
"index_fill",
|
||||
'index_fill',
|
||||
'index_select',
|
||||
'max',
|
||||
'argmax',
|
||||
'min',
|
||||
|
|
|
@ -247,6 +247,9 @@ def abs(x):
|
|||
def absolute(x):
|
||||
"""
|
||||
Alias for :func:`mindspore.ops.abs` .
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
return abs(x)
|
||||
|
||||
|
@ -1333,6 +1336,9 @@ def floor(x):
|
|||
def i0(x):
|
||||
r"""
|
||||
Alias for :func:`mindspore.ops.bessel_i0` .
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
"""
|
||||
return bessel_i0(x)
|
||||
|
||||
|
@ -2125,6 +2131,9 @@ def acos(x):
|
|||
def arccos(x):
|
||||
"""
|
||||
Alias for :func:`mindspore.ops.acos` .
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
return acos(x)
|
||||
|
||||
|
@ -3468,6 +3477,16 @@ def less(x, y):
|
|||
return tensor_lt(x, y)
|
||||
|
||||
|
||||
def lt(x, other):
|
||||
"""
|
||||
Alias for :func:`mindspore.ops.less` .
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
return less(x, other)
|
||||
|
||||
|
||||
def le(x, y):
|
||||
r"""
|
||||
Computes the boolean value of :math:`x <= y` element-wise.
|
||||
|
@ -9660,6 +9679,7 @@ __all__ = [
|
|||
'negative',
|
||||
'tensor_lt',
|
||||
'less',
|
||||
'lt',
|
||||
'logaddexp2',
|
||||
'tensor_le',
|
||||
'lcm',
|
||||
|
|
|
@ -177,6 +177,7 @@ tensor_operator_registry.register('positive', positive)
|
|||
tensor_operator_registry.register('permute', permute)
|
||||
tensor_operator_registry.register('remainder', remainder)
|
||||
tensor_operator_registry.register('index_fill', index_fill)
|
||||
tensor_operator_registry.register('index_select', index_select)
|
||||
tensor_operator_registry.register('flip', flip)
|
||||
tensor_operator_registry.register('fliplr', fliplr)
|
||||
tensor_operator_registry.register('flipud', flipud)
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, axis, index):
|
||||
output = ops.index_select(x, axis, index)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_index_select(mode):
|
||||
"""
|
||||
Feature: index_select
|
||||
Description: Verify the result of index_select
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32))
|
||||
axis = 1
|
||||
index = ms.Tensor([0, 2], ms.int32)
|
||||
expect_output = np.array([[1., 3.], [4., 6.], [7., 9.]], dtype=np.float32)
|
||||
out = net(x, axis, index)
|
||||
assert np.allclose(out.asnumpy(), expect_output)
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, other):
|
||||
output = ops.lt(x, other)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_lt(mode):
|
||||
"""
|
||||
Feature: ops.lt
|
||||
Description: Verify the result of op.lt
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor(np.array([1, 2, 3]), ms.int32)
|
||||
y = ms.Tensor(np.array([1, 1, 4]), ms.int32)
|
||||
out = net(x, y)
|
||||
expect_output = np.array([False, False, True], np.bool_)
|
||||
assert (out.asnumpy() == expect_output).any()
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, axis, index):
|
||||
output = x.index_select(axis, index)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_index_select(mode):
|
||||
"""
|
||||
Feature: Tensor.index_select
|
||||
Description: Verify the result of index_select
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32))
|
||||
axis = 1
|
||||
index = ms.Tensor([0, 2], ms.int32)
|
||||
expect_output = np.array([[1., 3.], [4., 6.], [7., 9.]], dtype=np.float32)
|
||||
out = net(x, axis, index)
|
||||
assert np.allclose(out.asnumpy(), expect_output)
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, other):
|
||||
output = x.lt(other)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_lt(mode):
|
||||
"""
|
||||
Feature: Tensor.lt
|
||||
Description: Verify the result of Tensor.lt
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor(np.array([1, 2, 3]), ms.int32)
|
||||
y = ms.Tensor(np.array([1, 1, 4]), ms.int32)
|
||||
out = net(x, y)
|
||||
expect_output = np.array([False, False, True], np.bool_)
|
||||
assert (out.asnumpy() == expect_output).any()
|
Loading…
Reference in New Issue