!44947 support filp fliplr flipud isfloat issigned ops and tensor

Merge pull request !44947 from 冯一航/support_flip_isf_issigned
This commit is contained in:
i-robot 2022-11-08 02:50:29 +00:00 committed by Gitee
commit 2c6e29775f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
36 changed files with 1134 additions and 5 deletions

View File

@ -131,6 +131,8 @@ mindspore.ops.function
mindspore.ops.less_equal
mindspore.ops.igamma
mindspore.ops.igammac
mindspore.ops.is_floating_point
mindspore.ops.is_signed
逐元素运算
^^^^^^^^^^^^^
@ -336,6 +338,9 @@ Array操作
mindspore.ops.dyn_shape
mindspore.ops.expand
mindspore.ops.expand_dims
mindspore.ops.flip
mindspore.ops.fliplr
mindspore.ops.flipud
mindspore.ops.fold
mindspore.ops.gather
mindspore.ops.gather_d

View File

@ -0,0 +1,6 @@
mindspore.Tensor.flip
======================
.. py:method:: mindspore.Tensor.flip(dims)
详情请参考 :func:`mindspore.ops.flip`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.fliplr
========================
.. py:method:: mindspore.Tensor.fliplr()
详情请参考 :func:`mindspore.ops.fliplr`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.flipud
========================
.. py:method:: mindspore.Tensor.flipud()
详情请参考 :func:`mindspore.ops.flipud`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.is_floating_point
===================================
.. py:method:: mindspore.Tensor.is_floating_point()
详情请参考 :func:`mindspore.ops.is_floating_point`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.is_signed
===========================
.. py:method:: mindspore.Tensor.is_signed()
详情请参考 :func:`mindspore.ops.is_signed`

View File

@ -83,6 +83,9 @@ mindspore.Tensor
mindspore.Tensor.fill
mindspore.Tensor.fills
mindspore.Tensor.flatten
mindspore.Tensor.flip
mindspore.Tensor.fliplr
mindspore.Tensor.flipud
mindspore.Tensor.float
mindspore.Tensor.floor
mindspore.Tensor.flush_from_cache
@ -110,8 +113,10 @@ mindspore.Tensor
mindspore.Tensor.invert
mindspore.Tensor.isclose
mindspore.Tensor.isfinite
mindspore.Tensor.is_floating_point
mindspore.Tensor.isinf
mindspore.Tensor.isnan
mindspore.Tensor.is_signed
mindspore.Tensor.item
mindspore.Tensor.itemset
mindspore.Tensor.itemsize

View File

@ -0,0 +1,20 @@
mindspore.ops.flip
===================
.. py:function:: mindspore.ops.flip(x, dims)
沿给定轴翻转Tensor中元素的顺序。
Tensor的shape会被保留但是元素将重新排序。
参数:
- **x** (Tensor) - 输入tensor。
- **dims** (tuple[int]) - 需要翻转的一个轴或多个轴。在元组中指定的所有轴上执行翻转,如果 `dims` 是一个包含负数的整数元组,则该轴为按倒序计数的轴位置。
返回:
返回沿给定轴翻转计算结果的tensor。
异常:
- **TypeError** - `x` 不是Tensor。
- **ValueError** - `dims` 为None。
- **ValueError** - `dims` 不为int组成的tuple。

View File

@ -0,0 +1,16 @@
mindspore.ops.fliplr
=====================
.. py:function:: mindspore.ops.fliplr(x)
沿左右方向翻转Tensor中每行的元素。
Tensor的列会被保留但显示顺序将与以前不同。
参数:
- **x** (Tensor) - 输入tensor。
返回:
Tensor。
异常:
- **TypeError** - `x` 不是Tensor。

View File

@ -0,0 +1,16 @@
mindspore.ops.flipud
=====================
.. py:function:: mindspore.ops.flipud(x)
沿上下方向翻转Tensor中每行的元素。
Tensor的行会被保留但显示顺序将与以前不同。
参数:
- **x** (Tensor) - 输入tensor。
返回:
Tensor。
异常:
- **TypeError** - `x` 不是Tensor。

View File

@ -0,0 +1,12 @@
mindspore.ops.is_floating_point
================================
.. py:function:: mindspore.ops.is_floating_point(x)
判断 `x` 的dtype是否是浮点数据类型包括mindspore.flot64mindspore.float32mindspore.float16。
参数:
- **x** (Tensor) - 输入Tensor。
返回:
Bool如果 `x` 的dtype是浮点数据类型则返回True否则返回False。

View File

@ -0,0 +1,12 @@
mindspore.ops.is_signed
========================
.. py:function:: mindspore.ops.is_signed(x)
判断 `x` 的dtype是否是有符号数类型。
参数:
- **x** (Tensor) - 输入Tensor。
返回:
Bool如果 `x` 的dtype是有符号数类型则返回True否则返回False。

View File

@ -89,6 +89,9 @@
mindspore.Tensor.fill
mindspore.Tensor.fills
mindspore.Tensor.flatten
mindspore.Tensor.flip
mindspore.Tensor.fliplr
mindspore.Tensor.flipud
mindspore.Tensor.float
mindspore.Tensor.floor
mindspore.Tensor.flush_from_cache
@ -116,8 +119,10 @@
mindspore.Tensor.invert
mindspore.Tensor.isclose
mindspore.Tensor.isfinite
mindspore.Tensor.is_floating_point
mindspore.Tensor.isinf
mindspore.Tensor.isnan
mindspore.Tensor.is_signed
mindspore.Tensor.item
mindspore.Tensor.itemset
mindspore.Tensor.itemsize

View File

@ -132,6 +132,8 @@ Mathematical Functions
mindspore.ops.less_equal
mindspore.ops.igamma
mindspore.ops.igammac
mindspore.ops.is_floating_point
mindspore.ops.is_signed
Element-by-Element Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -336,6 +338,9 @@ Array Operation
mindspore.ops.dyn_shape
mindspore.ops.expand
mindspore.ops.expand_dims
mindspore.ops.flip
mindspore.ops.fliplr
mindspore.ops.flipud
mindspore.ops.fold
mindspore.ops.gather
mindspore.ops.gather_d

View File

@ -265,6 +265,8 @@ BuiltInTypeMap &GetMethodMap() {
{"choose", std::string("choose")}, // P.Select()
{"diagonal", std::string("diagonal")}, // P.Eye()
{"isclose", std::string("isclose")}, // P.IsClose()
{"is_floating_point", std::string("is_floating_point")}, // is_floating_point()
{"is_signed", std::string("is_signed")}, // is_signed()
{"inv", std::string("inv")}, // inv()
{"invert", std::string("invert")}, // invert()
{"searchsorted", std::string("searchsorted")}, // P.Select()
@ -296,6 +298,9 @@ BuiltInTypeMap &GetMethodMap() {
{"bernoulli", prim::kPrimBernoulli}, // P.Bernoulli()
{"ceil", std::string("ceil")}, // P.Ceil
{"floor", std::string("floor")}, // P.floor
{"flip", std::string("flip")}, // flip
{"fliplr", std::string("fliplr")}, // fliplr
{"flipud", std::string("flipud")}, // flipud
{"hardshrink", std::string("hardshrink")}, // P.hshrink
{"soft_shrink", std::string("soft_shrink")}, // P.SoftShrink
{"gather_nd", std::string("gather_nd")}, // P.GatherNd()

View File

@ -1231,6 +1231,41 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
return F.isclose(x1, x2, rtol, atol, equal_nan)
def flip(x, dims):
"""
For details, please refer to :func:`mindspore.ops.flip`.
"""
return F.flip(x, dims)
def fliplr(x):
"""
For details, please refer to :func:`mindspore.ops.fliplr`.
"""
return F.fliplr(x)
def flipud(x):
"""
For details, please refer to :func:`mindspore.ops.flipud`.
"""
return F.flipud(x)
def is_floating_point(x):
"""
For details, please refer to :func:`mindspore.ops.is_floating_point`.
"""
return F.is_floating_point(x)
def is_signed(x):
"""
For details, please refer to :func:`mindspore.ops.is_signed`.
"""
return F.is_signed(x)
def inv(x):
"""
Computes Reciprocal of input tensor element-wise.

View File

@ -3993,6 +3993,36 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('isnan')(self)
def flip(self, dims):
"""
For details, please refer to :func:`mindspore.ops.flip`.
"""
return tensor_operator_registry.get('flip')(self, dims)
def fliplr(self):
"""
For details, please refer to :func:`mindspore.ops.fliplr`.
"""
return tensor_operator_registry.get('fliplr')(self)
def flipud(self):
"""
For details, please refer to :func:`mindspore.ops.flipud`.
"""
return tensor_operator_registry.get('flipud')(self)
def is_floating_point(self):
"""
For details, please refer to :func:`mindspore.ops.is_floating_point`.
"""
return tensor_operator_registry.get('is_floating_point')(self)
def is_signed(self):
"""
For details, please refer to :func:`mindspore.ops.is_signed`.
"""
return tensor_operator_registry.get('is_signed')(self)
def le(self, other):
r"""
For details, please refer to :func:`mindspore.ops.le`.
@ -4306,7 +4336,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('ne')(self, other)
def sinh(self):
r"""
Computes hyperbolic sine of the input element-wise.
@ -4330,7 +4359,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('sinh')(self)
def sort(self, dim=-1, descending=False):
r"""
Sorts the elements of the input tensor along a given dimension in ascending order by value.
@ -4368,7 +4396,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('sort')(axis=dim, descending=descending)(self)
def trunc(self):
r"""
Returns a new tensor with the truncated integer values of the elements of input.
@ -4388,7 +4415,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('trunc')(self)
def imag(self):
r"""
Returns a new tensor containing imaginary value of the input tensor.

View File

@ -318,10 +318,15 @@ from .nn_func import (
dropout3d,
deformable_conv2d,
fast_gelu,
flip,
fliplr,
flipud,
pixel_shuffle,
pixel_unshuffle,
hardshrink,
soft_shrink,
is_floating_point,
is_signed,
intopk,
interpolate,
kl_div,

View File

@ -40,6 +40,9 @@ hardswish_ = P.HSwish()
mish_ = NN_OPS.Mish()
selu_ = NN_OPS.SeLU()
sigmoid_ = NN_OPS.Sigmoid()
signed_type = [mstype.int8, mstype.byte, mstype.int16, mstype.short, mstype.int32, mstype.intc, mstype.int64,
mstype.intp, mstype.float16, mstype.half, mstype.float32, mstype.single, mstype.float64,
mstype.double, mstype.complex64, mstype.complex128]
def adaptive_avg_pool2d(input_x, output_size):
@ -1352,6 +1355,237 @@ def hardshrink(x, lambd=0.5):
return hshrink_op(x)
@constexpr
def _check_axis_in_range(axis, ndim):
"""Checks axes are with the bounds of ndim"""
if not isinstance(axis, int):
raise TypeError(f'The dims must be integers, but got {type(axis)}')
if not -ndim <= axis < ndim:
raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {axis}.")
return axis % ndim
@constexpr
def _check_axis_valid(axes, ndim):
"""
Checks axes are valid given ndim, and returns axes that can be passed
to the built-in operator (non-negative, int or tuple)
"""
if axes is None:
raise ValueError(f"The parameter dims can not be None.")
if isinstance(axes, (tuple, list)):
axes = tuple(map(lambda x: _check_axis_in_range(x, ndim), axes))
if any(axes.count(el) > 1 for el in axes):
raise ValueError(f"The element of parameter 'dims' can not be duplicate, but got {axes}.")
return axes
raise ValueError(f"The parameter dims must be tuple of ints, but got {type(axes)}")
@constexpr
def _get_flip_start(ndim, shape, axes):
"""Calculate the start index of flip"""
return tuple([shape[i] - 1 if i in axes else 0 for i in range(ndim)])
@constexpr
def _get_flip_end(ndim, shape, axes):
"""Calculate the end index of flip"""
return tuple([-shape[i] - 1 if i in axes else shape[i] + 1 for i in range(ndim)])
@constexpr
def _get_flip_strides(ndim, axes):
"""Calculate the strides of flip"""
return tuple([-1 if i in axes else 1 for i in range(ndim)])
@constexpr
def _is_shape_empty(shp):
"""Check whether shape contains zero"""
if isinstance(shp, int):
return shp == 0
return ops.shape_mul(shp) == 0
def _check_input_tensor(arg_name, *tensors):
"""Check whether the input is tensor"""
for tensor in tensors:
if not isinstance(tensor, Tensor):
raise TypeError(f"For '{arg_name}', the input must be Tensor, but got {ops.typeof(tensor)}")
return True
def flip(x, dims):
"""
Reverses the order of elements in a tensor along the given axis.
The shape of the tensor is preserved, but the elements are reordered.
Args:
x (Tensor): Input tensor.
dims (Union[list[int], tuple[int]]): Axis or axes along which to flip over.
Flipping is performed on all of the axes specified in the tuple,
If `dims` is a tuple of integers contains negative, it counts from the last to the first axis.
Returns:
Tensor, with the entries of `dims` reversed.
Raises:
TypeError: If the input is not a tensor.
ValueError: If `dims` is None.
ValueError: If `dims` is not a tuple of ints.
Supported Platforms:
``GPU`` ``CPU``
Example:
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> import numpy as np
>>> x = ms.Tensor(np.arange(8).reshape((2, 2, 2)))
>>> output = ops.flip(x, (0, 2))
>>> print(output)
[[[5. 4.]
[7. 6.]]
[[1. 0.]
[3. 2.]]]
"""
_check_input_tensor("flip", x)
ndim = ops.rank(x)
shape = ops.shape(x)
dims = _check_axis_valid(dims, ndim)
if _is_shape_empty(shape):
return x
start = _get_flip_start(ndim, shape, dims)
end = _get_flip_end(ndim, shape, dims)
strides = _get_flip_strides(ndim, dims)
res = ops.strided_slice(x, start, end, strides)
return res
def flipud(x):
"""
Flips the entries in each column in the up/down direction.
Rows are preserved, but appear in a different order than before.
Args:
x (Tensor): Input array.
Returns:
Tensor.
Raises:
TypeError: If the input is not a tensor.
Supported Platforms:
``GPU`` ``CPU``
Example:
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> import numpy as np
>>> x = ms.Tensor(np.arange(8).reshape((2, 2, 2)))
>>> output = ops.flipud(x)
>>> print(output)
[[[4. 5.]
[6. 7.]]
[[0. 1.]
[2. 3.]]]
"""
return flip(x, (0,))
def fliplr(x):
"""
Flips the entries in each row in the left/right direction.
Columns are preserved, but appear in a different order than before.
Args:
x (Tensor): Input tensor.
Returns:
Tensor.
Raises:
TypeError: If the input is not a tensor.
Supported Platforms:
``GPU`` ``CPU``
Example:
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> import numpy as np
>>> x = ms.Tensor(np.arange(8).reshape((2, 2, 2)))
>>> output = ops.fliplr(x)
>>> print(output)
[[[2. 3.]
[0. 1.]]
[[6. 7.]
[4. 5.]]]
"""
return flip(x, (1,))
def is_floating_point(x):
"""
Judge whether the data type of `x` is a floating point data type i.e., one of mindspore.flot64, mindspore.float32,
mindspore.float16.
Args:
x (Tensor): The input Tensor.
Returns:
Bool. If the dtype of `x` is a floating point data type, return True. Otherwise, return False.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> x = ms.Tensor([1, 2, 3], ms.float32)
>>> y = ms.Tensor([1, 2, 3], ms.int64)
>>> output = ops.is_floating_point(x)
>>> output2 = ops.is_floating_point(y)
>>> print(output)
True
>>> print(output2)
False
"""
return x.dtype in [mstype.float32, mstype.float16, mstype.float64]
def is_signed(x):
"""
Judge whether the data type of `x` is a signed data type.
Args:
x (Tensor): The input tensor.
Returns:
Bool. If the dtype of `x` is a signed data type, return True. Otherwise, return False.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> x = ms.Tensor([1, 2, 3], ms.int64)
>>> y = ms.Tensor([1, 2, 3], ms.uint64)
>>> output = ops.is_signed(x)
>>> output2 = ops.is_signed(y)
>>> print(output)
True
>>> print(output2)
False
"""
return x.dtype in signed_type
def hardswish(x):
r"""
Hard swish activation function.
@ -4358,6 +4592,11 @@ __all__ = [
'pixel_unshuffle',
'hardshrink',
'soft_shrink',
'is_floating_point',
'is_signed',
'flip',
'fliplr',
'flipud',
'intopk',
'interpolate',
'log_softmax',

View File

@ -370,6 +370,11 @@ tensor_operator_registry.register('cumsum', P.CumSum)
tensor_operator_registry.register('cummin', cummin)
tensor_operator_registry.register('cummax', cummax)
tensor_operator_registry.register('index_fill', index_fill)
tensor_operator_registry.register('flip', flip)
tensor_operator_registry.register('fliplr', fliplr)
tensor_operator_registry.register('flipud', flipud)
tensor_operator_registry.register('is_signed', is_signed)
tensor_operator_registry.register('is_floating_point', is_floating_point)
tensor_operator_registry.register('bitwise_and', bitwise_and)
tensor_operator_registry.register('bitwise_or', bitwise_or)
tensor_operator_registry.register('bitwise_xor', bitwise_xor)

50
tests/st/ops/test_flip.py Normal file
View File

@ -0,0 +1,50 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x):
output = ops.flip(x, (0, 2))
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flip_normal(mode):
"""
Feature: flip
Description: Verify the result of flip
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.arange(8).reshape((2, 2, 2)))
out = net(x)
expect_out = np.array([[[5., 4.],
[7., 6.]],
[[1., 0.],
[3., 2.]]])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,50 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x):
output = ops.fliplr(x)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fliplr_normal(mode):
"""
Feature: fliplr
Description: Verify the result of fliplr
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.arange(8).reshape((2, 2, 2)))
out = net(x)
expect_out = np.array([[[2., 3.],
[0., 1.]],
[[6., 7.],
[4., 5.]]])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,50 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x):
output = ops.flipud(x)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flipud_normal(mode):
"""
Feature: flipud
Description: Verify the result of flipud
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.arange(8).reshape((2, 2, 2)))
out = net(x)
expect_out = np.array([[[4., 5.],
[6., 7.]],
[[0., 1.],
[2., 3.]]])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,50 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x):
output = ops.is_floating_point(x)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_is_floating_point_normal(mode):
"""
Feature: is_floating_point
Description: Verify the result of is_floating_point
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor([1, 2, 3], ms.float32)
y = ms.Tensor([1, 2, 3], ms.int64)
out1 = net(x)
out2 = net(y)
assert out1
assert not out2

View File

@ -0,0 +1,50 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x):
output = ops.is_signed(x)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_is_signed_normal(mode):
"""
Feature: is_signed
Description: Verify the result of is_signed
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor([1, 2, 3], ms.int64)
y = ms.Tensor([1, 2, 3], ms.uint64)
out1 = net(x)
out2 = net(y)
assert out1
assert not out2

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
output = x.flip((0, 2))
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flip_normal(mode):
"""
Feature: tensor.flip
Description: Verify the result of flip
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.arange(8).reshape((2, 2, 2)))
out = net(x)
expect_out = np.array([[[5., 4.],
[7., 6.]],
[[1., 0.],
[3., 2.]]])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
output = x.fliplr()
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fliplr_normal(mode):
"""
Feature: tensor.fliplr
Description: Verify the result of fliplr
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.arange(8).reshape((2, 2, 2)))
out = net(x)
expect_out = np.array([[[2., 3.],
[0., 1.]],
[[6., 7.],
[4., 5.]]])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
output = x.flipud()
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flipud_normal(mode):
"""
Feature: tensor.flipud
Description: Verify the result of flipud
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.arange(8).reshape((2, 2, 2)))
out = net(x)
expect_out = np.array([[[4., 5.],
[6., 7.]],
[[0., 1.],
[2., 3.]]])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
output = x.is_floating_point()
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_is_floating_point_normal(mode):
"""
Feature: tensor.is_floating_point
Description: Verify the result of is_floating_point
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor([1, 2, 3], ms.float32)
y = ms.Tensor([1, 2, 3], ms.int64)
out1 = net(x)
out2 = net(y)
assert out1
assert not out2

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
output = x.is_signed()
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_is_signed_normal(mode):
"""
Feature: tensor.is_signed
Description: Verify the result of is_signed
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor([1, 2, 3], ms.int64)
y = ms.Tensor([1, 2, 3], ms.uint64)
out1 = net(x)
out2 = net(y)
assert out1
assert not out2

View File

@ -0,0 +1,39 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test flip api
"""
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.api import _cell_graph_executor
class Roll(nn.Cell):
def construct(self, x):
return ops.flip(x, (0, 2))
def test_compile_flip():
"""
Feature: Test filp
Description: Test the functionality of flip
Expectation: Success
"""
net = Roll()
x = ms.Tensor(np.arange(8).reshape((2, 2, 2)))
_cell_graph_executor.compile(net, x)

View File

@ -0,0 +1,39 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test fliplr api
"""
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.api import _cell_graph_executor
class Roll(nn.Cell):
def construct(self, x):
return ops.fliplr(x)
def test_compile_fliplr():
"""
Feature: Test filplr
Description: Test the functionality of fliplr
Expectation: Success
"""
net = Roll()
x = ms.Tensor(np.arange(8).reshape((2, 2, 2)))
_cell_graph_executor.compile(net, x)

View File

@ -0,0 +1,39 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test flipud api
"""
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.api import _cell_graph_executor
class Roll(nn.Cell):
def construct(self, x):
return ops.flipud(x)
def test_compile_flipud():
"""
Feature: Test flipud
Description: Test the functionality of flipud
Expectation: Success
"""
net = Roll()
x = ms.Tensor(np.arange(8).reshape((2, 2, 2)))
_cell_graph_executor.compile(net, x)

View File

@ -0,0 +1,38 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test is floating point api
"""
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.api import _cell_graph_executor
class Roll(nn.Cell):
def construct(self, x):
return ops.is_floating_point(x)
def test_compile_is_floating_point():
"""
Feature: Test is floating point
Description: Test the functionality of is floating point
Expectation: Success
"""
net = Roll()
x = ms.Tensor([1, 2, 3], ms.float32)
_cell_graph_executor.compile(net, x)

View File

@ -0,0 +1,38 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test is signed api
"""
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.api import _cell_graph_executor
class Roll(nn.Cell):
def construct(self, x):
return ops.is_signed(x)
def test_compile_is_signed():
"""
Feature: Test is signed
Description: Test the functionality of is signed
Expectation: Success
"""
net = Roll()
x = ms.Tensor([1, 2, 3], ms.int64)
_cell_graph_executor.compile(net, x)

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""
test pooling api
test roll api
"""
import numpy as np