!44610 tensor_nelement_numel_permute_positive_remainder

Merge pull request !44610 from yide12/tensor_nelement_numel_permute_positive_remainder
This commit is contained in:
i-robot 2022-11-08 02:54:33 +00:00 committed by Gitee
commit 5fce0f724d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
27 changed files with 688 additions and 0 deletions

View File

@ -201,6 +201,7 @@ mindspore.ops.function
mindspore.ops.matrix_determinant
mindspore.ops.mul
mindspore.ops.neg
mindspore.ops.positive
mindspore.ops.pow
mindspore.ops.roll
mindspore.ops.round
@ -365,6 +366,8 @@ Array Operations
mindspore.ops.meshgrid
mindspore.ops.normal
mindspore.ops.nonzero
mindspore.ops.numel
mindspore.ops.permute
mindspore.ops.population_count
mindspore.ops.range
mindspore.ops.rank

View File

@ -0,0 +1,8 @@
mindspore.Tensor.nelement
==========================
.. py:method:: mindspore.Tensor.nelement()
Alias for numel().
For details, please refer to :func:`mindspore.ops.numel`.

View File

@ -0,0 +1,6 @@
mindspore.Tensor.numel
=======================
.. py:method:: mindspore.Tensor.numel()
For details, please refer to :func:`mindspore.ops.numel`.

View File

@ -0,0 +1,6 @@
mindspore.Tensor.permute
=========================
.. py:method:: mindspore.Tensor.permute(*dims)
For details, please refer to :func:`mindspore.ops.permute`.

View File

@ -0,0 +1,6 @@
mindspore.Tensor.positive
==========================
.. py:method:: mindspore.Tensor.positive()
For details, please refer to :func:`mindspore.ops.positive`.

View File

@ -0,0 +1,6 @@
mindspore.Tensor.remainder
===========================
.. py:method:: mindspore.Tensor.remainder(divisor)
For details, please refer to :func:`mindspore.ops.remainder`.

View File

@ -151,13 +151,18 @@ mindspore.Tensor
mindspore.Tensor.ndim
mindspore.Tensor.ndimension
mindspore.Tensor.negative
mindspore.Tensor.nelement
mindspore.Tensor.numel
mindspore.Tensor.nonzero
mindspore.Tensor.norm
mindspore.Tensor.permute
mindspore.Tensor.positive
mindspore.Tensor.pow
mindspore.Tensor.prod
mindspore.Tensor.ptp
mindspore.Tensor.random_categorical
mindspore.Tensor.ravel
mindspore.Tensor.remainder
mindspore.Tensor.renorm
mindspore.Tensor.repeat
mindspore.Tensor.repeat_interleave

View File

@ -0,0 +1,12 @@
mindspore.ops.numel
====================
.. py:function:: mindspore.ops.numel(x)
Returns the total number of elements in the Tensor.
Parameters:
- **x** (Tensor) - Input Tensor.
Returns:
int. The total number of elements in the Tensor.
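A minimal usage sketch of the new op, assuming the standard MindSpore imports; it mirrors the test added later in this commit:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), ms.float32)
print(ops.numel(x))  # 24, the total element count returned as an int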

View File

@ -0,0 +1,16 @@
mindspore.ops.permute
=====================
.. py:function:: mindspore.ops.permute(x, dims)
Permutes the dimensions of the input Tensor into the order specified by `dims`.
Parameters:
- **x** (Tensor) - Input Tensor.
- **dims** (Union[tuple(int), list(int), int]) - The target order of the dimensions; permute rearranges `x` according to `dims`.
Returns:
Tensor, with the same number of dimensions as the input Tensor, rearranged according to `dims`.
Raises:
- **ValueError** - The number of elements in `dims` is not equal to the number of dimensions of `x`.
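A minimal sketch of the behaviour, assuming the standard imports; it matches the docstring example added in math_func.py below:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), ms.float32)  # shape (2, 3, 4)
y = ops.permute(x, (0, 2, 1))                                    # swap the last two axes
print(y.shape)  # (2, 4, 3)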

View File

@ -0,0 +1,15 @@
mindspore.ops.positive
======================
.. py:function:: mindspore.ops.positive(x)
Returns the input Tensor as-is.
Parameters:
- **x** (Tensor) - Input Tensor.
Returns:
The input Tensor.
Raises:
- **TypeError** - The dtype of `x` is bool.
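A minimal sketch, assuming the standard imports; note that a bool tensor raises TypeError as documented:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]), ms.float32)
print(ops.positive(x))  # the tensor is returned unchanged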

View File

@ -0,0 +1,28 @@
mindspore.ops.remainder
=======================
.. py:function:: mindspore.ops.remainder(x, y)
Computes, element-wise, the remainder of dividing the first input by the second input.
Inputs `x` and `y` follow the implicit type conversion rules to make the data types consistent. The inputs must be two Tensors, or one Tensor and one Scalar. When the inputs are two Tensors, neither dtype may be bool and their shapes must be broadcastable. When the inputs are one Tensor and one Scalar, the Scalar may only be a constant.
.. math::
out_{i} = x_{i} \text{ % } y_{i}
.. warning::
- Input values of 0 are not supported.
- When the input has more than 2048 elements, the accuracy of the operator cannot be guaranteed to meet the two-thousandths requirement of the mini form.
- Due to architectural differences, the results of this operator may be inconsistent between NPU and CPU.
- If the shape is expressed as (D1, D2, ..., Dn), then D1 * D2 * ... * Dn <= 1000000 and n <= 8.
Parameters:
- **x** (Union[Tensor, numbers.Number, bool]) - The first input, which may be a number, a bool, or a Tensor whose dtype is number.
- **y** (Union[Tensor, numbers.Number, bool]) - When the first input is a Tensor, the second input may be a number, a bool, or a Tensor whose dtype is number.
Returns:
Tensor, with the shape obtained by broadcasting the inputs and the data type of the input with the higher precision or the larger number of digits.
Raises:
- **TypeError** - Neither `x` nor `y` is a Tensor, number, or bool.
- **ValueError** - The shapes of `x` and `y` cannot be broadcast to each other.
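A minimal sketch of the tensor/scalar and tensor/tensor cases, assuming the standard imports; the scalar case mirrors the test added at the end of this commit:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.array([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]), ms.float32)
print(ops.remainder(x, -1.5))   # [ 0.  -0.5 -1.  -0.5 -1.   0. ], the sign follows the divisor
y = Tensor(np.array([1.5, 2.0, 3.5, 1.5, 2.0, 3.5]), ms.float32)
print(ops.remainder(x, y))      # element-wise remainder with a same-shape divisor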

View File

@ -157,13 +157,18 @@
mindspore.Tensor.ndim
mindspore.Tensor.ndimension
mindspore.Tensor.negative
mindspore.Tensor.nelement
mindspore.Tensor.numel
mindspore.Tensor.nonzero
mindspore.Tensor.norm
mindspore.Tensor.permute
mindspore.Tensor.positive
mindspore.Tensor.pow
mindspore.Tensor.prod
mindspore.Tensor.ptp
mindspore.Tensor.random_categorical
mindspore.Tensor.ravel
mindspore.Tensor.remainder
mindspore.Tensor.renorm
mindspore.Tensor.repeat
mindspore.Tensor.repeat_interleave

View File

@ -201,6 +201,7 @@ Element-by-Element Operations
mindspore.ops.matrix_determinant
mindspore.ops.mul
mindspore.ops.neg
mindspore.ops.positive
mindspore.ops.pow
mindspore.ops.roll
mindspore.ops.round
@ -364,6 +365,8 @@ Array Operation
mindspore.ops.meshgrid
mindspore.ops.normal
mindspore.ops.nonzero
mindspore.ops.numel
mindspore.ops.permute
mindspore.ops.population_count
mindspore.ops.range
mindspore.ops.rank

View File

@ -247,6 +247,11 @@ BuiltInTypeMap &GetMethodMap() {
{"min", std::string("min")}, // P.reduce_min()
{"pow", std::string("pow")}, // P.Pow()
{"log", std::string("log")}, // P.Log()
{"nelement", std::string("numel")}, // numel()
{"numel", std::string("numel")}, // numel()
{"permute", std::string("permute")}, // permute()
{"positive", std::string("positive")}, // positive()
{"remainder", std::string("remainder")}, // remainder()
{"minimum", std::string("minimum")}, // P.Minimum()
{"cosh", std::string("cosh")}, // P.Cosh()
{"tanh", std::string("tanh")}, // P.Tanh()

View File

@ -1080,6 +1080,38 @@ def rot90(x, k, dims):
return F.rot90(x, k, dims)
def numel(x):
"""
Returns a Scalar of type int that represents the total number of elements in the Tensor.
"""
return F.numel(x)
def permute(x, *dims):
"""
Permutes the dimensions of the input tensor according to input permutation.
"""
if not dims:
raise ValueError(f"For Tensor.permute, the dims must not be empty.")
if len(dims) == 1:
return F.permute(x, *dims)
return F.permute(x, dims)
def positive(x):
"""
Return the input tensor itself.
"""
return F.positive(x)
def remainder(x, divisor):
"""
Returns element-wise remainder of division.
"""
return F.remainder(x, divisor)
def unique_consecutive(x, return_idx=False, return_counts=False, axis=None):
"""
Returns the elements that are unique in each consecutive group of equivalent elements in the input tensor.

View File

@ -1778,6 +1778,47 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('rot90')(self, k, dims)
def nelement(self):
r"""
Alias for numel().
For details, please refer to :func:`mindspore.ops.numel`.
"""
self._init_check()
return tensor_operator_registry.get('nelement')(self)
def numel(self):
r"""
For details, please refer to :func:`mindspore.ops.numel`.
"""
self._init_check()
return tensor_operator_registry.get('numel')(self)
def permute(self, *dims):
"""
For details, please refer to :func:`mindspore.ops.permute`.
"""
self._init_check()
if not dims:
raise ValueError(f"For Tensor.permute, the dims must not be empty.")
if len(dims) == 1:
return tensor_operator_registry.get("permute")(self, *dims)
return tensor_operator_registry.get("permute")(self, dims)
def positive(self):
"""
For details, please refer to :func:`mindspore.ops.positive`.
"""
self._init_check()
return tensor_operator_registry.get("positive")(self)
def remainder(self, divisor):
r"""
For details, please refer to :func:`mindspore.ops.remainder`.
"""
self._init_check()
return tensor_operator_registry.get('remainder')(self, divisor)
def flatten(self, order='C'):
r"""
For details, please refer to :func:`mindspore.ops.flatten`.

View File

@ -237,6 +237,9 @@ from .math_func import (
erfc,
cdist,
ceil,
positive,
numel,
permute,
bernoulli,
bessel_i0,
bessel_i0e,

View File

@ -104,6 +104,8 @@ tensor_le = P.LessEqual()
tensor_gt = P.Greater()
tensor_ge = P.GreaterEqual()
not_equal = P.NotEqual()
size_ = P.Size()
transpose_ = P.Transpose()
#####################################
# Private Operation Functions.
@ -465,6 +467,87 @@ def neg(x):
return neg_tensor(x)
def positive(x):
r"""
Returns the input tensor itself.
Args:
x (Tensor): Input Tensor.
Returns:
Tensor, the input itself.
Raises:
TypeError: If the dtype of the input Tensor is bool.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, ops
>>> x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]), ms.float32)
>>> print(ops.positive(x))
[-5.0, 1.5, 3.0, 100.0]
"""
if x.dtype == mstype.bool_:
raise TypeError("For positive, the type of tensor can not be bool.")
return x
def numel(x):
r"""
Returns a Scalar of type int that represents the total number of elements in the Tensor.
Args:
x (Tensor): Input Tensor.
Returns:
int. A scalar representing the total number of elements in the Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> print(ops.numel(input_x))
4
"""
return size_(x)
def permute(x, dims):
"""
Permutes the dimensions of the input tensor according to input `dims` .
Args:
x (Tensor): Input Tensor.
dims (Union[tuple(int), list(int), int]): The order of dimensions; permute rearranges `x` according to `dims`.
Returns:
Tensor, with the same number of dimensions as the input tensor, permuted according to `dims`.
Raises:
ValueError: If `dims` is None.
ValueError: If the number of elements of `dims` is not equal to the tensor's ndim.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), mindspore.float32)
>>> input_perm = (0, 2, 1)
>>> print(ops.permute(input_x, input_perm))
[[[ 1. 4.]
[ 2. 5.]
[ 3. 6.]]
[[ 7. 10.]
[ 8. 11.]
[ 9. 12.]]]
"""
return transpose_(x, dims)
def ceil(x):
r"""
Rounds a tensor up to the closest integer element-wise.
@ -6970,6 +7053,8 @@ __all__ = [
'equal',
'not_equal',
'ne',
'numel',
'permute',
'inplace_update',
'inplace_add',
'inplace_sub',
@ -6987,6 +7072,7 @@ __all__ = [
'maximum',
'minimum',
'median',
'positive',
'floor',
'logical_not',
'logical_or',

View File

@ -369,6 +369,11 @@ tensor_operator_registry.register('argmax', P.Argmax)
tensor_operator_registry.register('cumsum', P.CumSum)
tensor_operator_registry.register('cummin', cummin)
tensor_operator_registry.register('cummax', cummax)
tensor_operator_registry.register('nelement', numel)
tensor_operator_registry.register('numel', numel)
tensor_operator_registry.register('positive', positive)
tensor_operator_registry.register('permute', permute)
tensor_operator_registry.register('remainder', remainder)
tensor_operator_registry.register('index_fill', index_fill)
tensor_operator_registry.register('flip', flip)
tensor_operator_registry.register('fliplr', fliplr)

View File

@ -0,0 +1,47 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, ops
class Net(nn.Cell):
def construct(self, x):
return ops.numel(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_ops_numel(mode):
"""
Feature: ops.numel
Description: Verify the result of numel
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), ms.float32)
net = Net()
output = net(x)
expect_output = 24
assert np.allclose(output, expect_output)

View File

@ -0,0 +1,55 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, ops
class Net(nn.Cell):
def construct(self, x, dims):
return ops.permute(x, dims)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_ops_permute(mode):
"""
Feature: ops.permute
Description: Verify the result of permute
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor(np.arange(2 * 3 * 4).reshape(2, 3, 4), ms.float32)
input_perm = (0, 2, 1)
net = Net()
output = net(x, input_perm)
expect_output = [[[0, 4, 8],
[1, 5, 9],
[2, 6, 10],
[3, 7, 11]],
[[12, 16, 20],
[13, 17, 21],
[14, 18, 22],
[15, 19, 23]]]
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,47 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, ops
class Net(nn.Cell):
def construct(self, x):
return ops.positive(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_ops_positive(mode):
"""
Feature: ops.positive
Description: Verify the result of positive
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]), ms.float32)
net = Net()
output = net(x)
expect_output = [-5.0, 1.5, 3.0, 100.0]
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,47 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x):
return x.nelement()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_nelement(mode):
"""
Feature: tensor.nelement
Description: Verify the result of nelement
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), ms.float32)
net = Net()
output = net(x)
expect_output = 24
assert np.allclose(output, expect_output)

View File

@ -0,0 +1,47 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x):
return x.numel()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_numel(mode):
"""
Feature: tensor.numel
Description: Verify the result of numel
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), ms.float32)
net = Net()
output = net(x)
expect_output = 24
assert np.allclose(output, expect_output)

View File

@ -0,0 +1,55 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x, dims):
return x.permute(dims)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_permute(mode):
"""
Feature: tensor.permute
Description: Verify the result of permute
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor(np.arange(2 * 3 * 4).reshape(2, 3, 4), ms.float32)
input_perm = (0, 2, 1)
net = Net()
output = net(x, input_perm)
expect_output = [[[0, 4, 8],
[1, 5, 9],
[2, 6, 10],
[3, 7, 11]],
[[12, 16, 20],
[13, 17, 21],
[14, 18, 22],
[15, 19, 23]]]
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,47 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x):
return x.positive()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_positive(mode):
"""
Feature: tensor.positive
Description: Verify the result of positive
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]), ms.float32)
net = Net()
output = net(x)
expect_output = [-5.0, 1.5, 3.0, 100.0]
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,52 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x, divisor):
return x.remainder(divisor)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_remainder(mode):
"""
Feature: tensor.remainder
Description: Verify the result of remainder
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor(np.array([-3, -2, -1, 1, 2, 3]), ms.float32)
net = Net()
output = net(x, -1.5)
expect_output1 = [0, -0.5, -1, -0.5, -1, 0]
assert np.allclose(output.asnumpy(), expect_output1)
x = Tensor(np.array([-30, -17, -3, 61, 17, 30]), ms.float32)
y = Tensor(np.array([-1.5, -2, -3.5, 1.5, 2, 3.5]), ms.float32)
output = net(x, y)
expect_output2 = [0, -1, -3, 1, 1, 2]
assert np.allclose(output.asnumpy(), expect_output2)