!46231 [API] Add ops.is_complex, ops.float_power, ops.fmod, tensor.is_complex, tensor.float_power and tensor.fmod.

Merge pull request !46231 from DavidFFFan/api
This commit is contained in:
i-robot 2022-12-02 06:48:18 +00:00 committed by Gitee
commit 805e244f4d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
24 changed files with 614 additions and 7 deletions

View File

@ -135,6 +135,7 @@ mindspore.ops
mindspore.ops.igamma
mindspore.ops.igammac
mindspore.ops.is_floating_point
mindspore.ops.is_complex
mindspore.ops.pinv
逐元素运算
@ -199,6 +200,8 @@ mindspore.ops
mindspore.ops.floor
mindspore.ops.floor_div
mindspore.ops.floor_mod
mindspore.ops.float_power
mindspore.ops.fmod
mindspore.ops.heaviside
mindspore.ops.hypot
mindspore.ops.i0

View File

@ -0,0 +1,6 @@
mindspore.Tensor.float_power
============================
.. py:method:: mindspore.Tensor.float_power(exponent)
详情请参考 :func:`mindspore.ops.float_power`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.fmod
=====================
.. py:method:: mindspore.Tensor.fmod(other)
详情请参考 :func:`mindspore.ops.fmod`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.is_complex
===========================
.. py:method:: mindspore.Tensor.is_complex()
详情请参考 :func:`mindspore.ops.is_complex`

View File

@ -104,8 +104,10 @@ mindspore.Tensor
mindspore.Tensor.fliplr
mindspore.Tensor.flipud
mindspore.Tensor.float
mindspore.Tensor.float_power
mindspore.Tensor.floor
mindspore.Tensor.flush_from_cache
mindspore.Tensor.fmod
mindspore.Tensor.fold
mindspore.Tensor.from_numpy
mindspore.Tensor.gather
@ -135,6 +137,7 @@ mindspore.Tensor
mindspore.Tensor.invert
mindspore.Tensor.isclose
mindspore.Tensor.isfinite
mindspore.Tensor.is_complex
mindspore.Tensor.is_floating_point
mindspore.Tensor.isinf
mindspore.Tensor.isnan

View File

@ -0,0 +1,20 @@
mindspore.ops.float_power
==========================
.. py:function:: mindspore.ops.float_power(x, exponent)
计算 `x` 的指数幂。对于实数类型使用mindspore.float64计算。对于复数类型使用输入数据相同类型计算。
.. note::
目前GPU平台不支持数据类型complex。
参数:
- **x** (Union[Tensor, Number]) - 第一个输入为Tensor或数值型数据类型。
- **exponent** (Union[Tensor, Number]) - 第二个输入如果第一个输入是Tensor第二个输入可以是数值型或Tensor。否则必须是Tensor。
返回:
Tensor输出的shape与广播后的shape相同。对于复数运算返回类型和输入数据类型相同。对于实数运算返回类型为mindspore.float64。
异常:
- **TypeError** - `x``exponent` 都不是Tensor。
- **TypeError** - `x``exponent` 数据类型不是Tensor或Number。

View File

@ -0,0 +1,22 @@
mindspore.ops.fmod
===================
.. py:function:: mindspore.ops.fmod(x, other)
计算除法运算 x/other 的浮点余数。
.. math::
out = x - n * other
其中 :math:`n`:math:`x/other` 结果中的整数部分。
返回值的符号和 `x` 相同,在数值上小于 `other`
参数:
- **x** (Union[Tensor, Number]) - 被除数。
- **other** (Union[Tensor, Number]) - 除数。
返回:
Tensor输出的shape与广播后的shape相同数据类型取两个输入中精度较高或数字较高的。
异常:
- **TypeError** - `x``other` 都不是Tensor。

View File

@ -0,0 +1,15 @@
mindspore.ops.is_complex
=========================
.. py:function:: mindspore.ops.is_complex(x)
如果Tensor的数据类型是复数则返回True否则返回False。
参数:
- **x** (Tensor) - 输入Tensor。
返回:
Bool返回Tensor的数据类型是否为complex。
异常:
- **TypeError** - `x` 不是Tensor。

View File

@ -110,9 +110,11 @@
mindspore.Tensor.fliplr
mindspore.Tensor.flipud
mindspore.Tensor.float
mindspore.Tensor.float_power
mindspore.Tensor.floor
mindspore.Tensor.flush_from_cache
mindspore.Tensor.fold
mindspore.Tensor.fmod
mindspore.Tensor.from_numpy
mindspore.Tensor.gather
mindspore.Tensor.gather_elements
@ -148,6 +150,7 @@
mindspore.Tensor.isposinf
mindspore.Tensor.isreal
mindspore.Tensor.is_signed
mindspore.Tensor.is_complex
mindspore.Tensor.item
mindspore.Tensor.itemset
mindspore.Tensor.itemsize

View File

@ -200,6 +200,8 @@ Element-by-Element Operations
mindspore.ops.floor
mindspore.ops.floor_div
mindspore.ops.floor_mod
mindspore.ops.float_power
mindspore.ops.fmod
mindspore.ops.heaviside
mindspore.ops.hypot
mindspore.ops.i0
@ -297,6 +299,7 @@ Comparison Functions
mindspore.ops.isneginf
mindspore.ops.isposinf
mindspore.ops.isreal
mindspore.ops.is_complex
mindspore.ops.le
mindspore.ops.less
mindspore.ops.maximum

View File

@ -292,6 +292,7 @@ BuiltInTypeMap &GetMethodMap() {
{"isclose", std::string("isclose")}, // P.IsClose()
{"is_floating_point", std::string("is_floating_point")}, // is_floating_point()
{"is_signed", std::string("is_signed")}, // is_signed()
{"is_complex", std::string("is_complex")}, // F.is_complex()
{"inv", std::string("inv")}, // inv()
{"inverse", std::string("inverse")}, // inverse()
{"invert", std::string("invert")}, // invert()
@ -330,6 +331,8 @@ BuiltInTypeMap &GetMethodMap() {
{"flip", std::string("flip")}, // flip
{"fliplr", std::string("fliplr")}, // fliplr
{"flipud", std::string("flipud")}, // flipud
{"float_power", std::string("float_power")}, // F.float_power
{"fmod", std::string("fmod")}, // F.fmod
{"hardshrink", std::string("hardshrink")}, // P.hshrink
{"heaviside", std::string("heaviside")}, // F.heaviside
{"hypot", std::string("hypot")}, // F.hypot

View File

@ -1418,6 +1418,20 @@ def flipud(x):
return F.flipud(x)
def float_power(x, exponent):
"""
For details, please refer to :func:`mindspore.ops.float_power`.
"""
return F.float_power(x, exponent)
def fmod(x, other):
"""
For details, please refer to :func:`mindspore.ops.fmod`.
"""
return F.fmod(x, other)
def is_floating_point(x):
"""
For details, please refer to :func:`mindspore.ops.is_floating_point`.
@ -1432,6 +1446,13 @@ def is_signed(x):
return x.dtype in mstype.signed_type
def is_complex(x):
"""
For details, please refer to :func:`mindspore.ops.is_complex`.
"""
return F.is_complex(x)
def inv(x):
"""
Computes Reciprocal of input tensor element-wise.

View File

@ -146,6 +146,7 @@ float_type = (float16, float32, float64,)
signed_type = (int8, byte, int16, short, int32, intc, int64,
intp, float16, half, float32, single, float64,
double, complex64, complex128)
complex_type = (complex64, complex128,)
all_types = (bool_, int8, uint8, int16, int32, int64, float16, float32, float64, complex64, complex128)
implicit_conversion_seq = {t: idx for idx, t in enumerate(all_types)}

View File

@ -1517,6 +1517,13 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('isfinite')()(self)
def is_complex(self):
r"""
For details, please refer to :func:`mindspore.ops.is_complex`.
"""
self._init_check()
return tensor_operator_registry.get('is_complex')(self)
def inv(self):
r"""
For details, please refer to :func:`mindspore.ops.inv`.
@ -1842,6 +1849,20 @@ class Tensor(Tensor_):
perm = tuple(range(self.ndim - 1, -1, -1))
return reshape_op(trans_op(self, perm), (-1,))
def float_power(self, other):
r"""
For details, please refer to :func:`mindspore.ops.float_power`.
"""
self._init_check()
return tensor_operator_registry.get('float_power')(self, other)
def fmod(self, other):
r"""
For details, please refer to :func:`mindspore.ops.fmod`.
"""
self._init_check()
return tensor_operator_registry.get('fmod')(self, other)
def narrow(self, axis, start, length):
"""
For details, please refer to :func:`mindspore.ops.narrow`.

View File

@ -179,6 +179,8 @@ from .math_func import (
tensor_floordiv,
floor_div,
floordiv,
float_power,
fmod,
xdivy,
tensor_pow,
pow,
@ -282,6 +284,7 @@ from .math_func import (
exp2,
deg2rad,
isreal,
is_complex,
rad2deg,
truncate_div,
truncate_mod,

View File

@ -107,7 +107,7 @@ def clip_by_value(x, clip_value_min=None, clip_value_max=None):
"""
def _clip_by_value(clip_min, clip_max, x):
if not isinstance(x, Tensor):
TypeError("Then type of 'x' must be Tensor")
raise TypeError("Then type of 'x' must be Tensor")
result = x
if clip_min is not None:
result = max_op(result, cast_op(clip_min, x.dtype))
@ -116,13 +116,13 @@ def clip_by_value(x, clip_value_min=None, clip_value_max=None):
return result
if clip_value_min is None and clip_value_max is None:
ValueError("At least one of 'clip_value_min' or 'clip_value_max' must not be None")
raise ValueError("At least one of 'clip_value_min' or 'clip_value_max' must not be None")
if not isinstance(x, (Tensor, tuple, list)):
TypeError("The input of 'clip_by_value' must be tensor or tuple[Tensor] or list[Tensor]")
raise TypeError("The input of 'clip_by_value' must be tensor or tuple[Tensor] or list[Tensor]")
if not isinstance(clip_value_min, (type(None), Tensor, float, int)):
TypeError("Then type of 'clip_value_min' must be not one of None, Tensor, float, int.")
raise TypeError("Then type of 'clip_value_min' must be not one of None, Tensor, float, int.")
if not isinstance(clip_value_max, (type(None), Tensor, float, int)):
TypeError("Then type of 'clip_value_max' must be not one of None, Tensor, float, int.")
raise TypeError("Then type of 'clip_value_max' must be not one of None, Tensor, float, int.")
if isinstance(clip_value_min, (float, int)):
clip_value_min = scalar2tensor_op(clip_value_min)
if isinstance(clip_value_max, (float, int)):

View File

@ -16,6 +16,7 @@
"""Defines math operators with functional form."""
import math
import numbers
from itertools import zip_longest
from collections import deque
import numpy as np
@ -114,6 +115,7 @@ tensor_ge = P.GreaterEqual()
not_equal_ = P.NotEqual()
size_ = P.Size()
transpose_ = P.Transpose()
cast_ = P.Cast()
#####################################
# Private Operation Functions.
@ -869,6 +871,59 @@ def divide(x, other, *, rounding_mode=None):
return div(x, other, rounding_mode)
def float_power(x, exponent):
"""
Computes `x` to the power of the exponent.
For the real number type, use mindspore.float64 to calculate.
For the complex type, use the same type of calculation as the input data.
.. Note::
On GPU, complex dtypes are not supported.
Args:
x (Union[Tensor, Number]): The first input is a tensor or a number.
exponent (Union[Tensor, Number]): The second input, if the first input is Tensor,
the second input can be Number or Tensor. Otherwise, it must be a Tensor.
Returns:
Tensor, the shape is the same as the one after broadcasting. For the complex type,
the return value type is the same as the input type. For the real number type,
the return value type is mindspore.float64.
Raises:
TypeError: If neither `x` nor `exponent` is a Tensor.
TypeError: If the data type of `x` or `exponent` is not in Tensor and Number.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([-1.5, 0., 2.]))
>>> output = ops.float_power(x, 2)
>>> print(output)
[2.25 0. 4. ]
"""
if not (isinstance(x, (Tensor, Tensor_)) or isinstance(exponent, (Tensor, Tensor_))):
raise TypeError("At least one of the types of inputs must be tensor, " + \
f"but the type of 'x' got is {type(x)}, " + \
f"and the type of 'exponent' is {type(exponent)}.")
if not isinstance(x, (Tensor, Tensor_, numbers.Number)):
raise TypeError(f"The type of 'x' must be Tensor or Number, but got {type(x)}.")
if not isinstance(exponent, (Tensor, Tensor_, numbers.Number)):
raise TypeError(f"The type of 'exponent' must be Tensor or Number, but got {type(exponent)}.")
if isinstance(x, (Tensor, Tensor_)) and is_complex(x) and isinstance(exponent, numbers.Number):
exponent = cast_(exponent, x.dtype)
elif isinstance(exponent, (Tensor, Tensor_)) and is_complex(exponent) and isinstance(x, numbers.Number):
x = cast_(x, exponent.dtype)
# If both x and exponent are complex Tensor, no processing is required.
elif not (isinstance(x, (Tensor, Tensor_)) and is_complex(x) and \
isinstance(exponent, (Tensor, Tensor_)) and is_complex(exponent)):
x = cast_(x, mstype.float64)
exponent = cast_(exponent, mstype.float64)
return pow(x, exponent)
def floor_div(x, y):
"""
Divides the first input tensor by the second input tensor element-wise and round down to the closest integer.
@ -912,6 +967,44 @@ def floor_div(x, y):
return tensor_floordiv(x, y)
def fmod(x, other):
"""
Computes the floating-point remainder of the division operation x/other.
.. math::
out = x - n * other
Where :math:`n` is :math:`x/other` with its fractional part truncated.
The returned value has the same sign as `x` and is less than `other` in magnitude.
Args:
x (Union[Tensor, Number]): the dividend.
other (Union[Tensor, Number]): the divisor.
Returns:
Tensor, the shape is the same as the one after broadcasting,
and the data type is the one with higher precision or higher digits among the two inputs.
Raises:
TypeError: If neither `x` nor `other` is a Tensor.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> x = Tensor(np.array([-3., -2, -1, 1, 2, 3]), mindspore.float32)
>>> output = ops.fmod(x, 1.5)
>>> print(output)
[ 0. -0.5 -1. 1. 0.5 0. ]
"""
if not (isinstance(x, (Tensor, Tensor_)) or isinstance(other, (Tensor, Tensor_))):
raise TypeError("At least one of the types of inputs must be tensor, " + \
f"but the type of 'x' got is {type(x)}, " + \
f"and the type of 'other' is {type(other)}.")
return x - div(x, other, rounding_mode="trunc") * other
def pow(x, y):
r"""
Calculates the `y` power of each element in `x`.
@ -2840,7 +2933,7 @@ def ldexp(x, other):
Note:
Typically this function can create floating point numbers
by multiplying mantissas in input with powers of intger 2
by multiplying mantissas in input with powers of integer 2
from the exponents in `other`.
Args:
@ -3413,6 +3506,34 @@ def isreal(x):
return imag_op(x) == 0
def is_complex(x):
'''
Return True if the data type of the tensor is complex, otherwise return False.
Args:
x (Tensor) - The input tensor.
Returns:
Bool, return whether the data type of the tensor is complex.
Raises:
TypeError: If `x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore import ops
>>> x = Tensor([1, 1+1j, 2+2j], mstype.complex64)
>>> output = ops.is_complex(x)
>>> print(output)
True
'''
if not isinstance(x, (Tensor, Tensor_)):
raise TypeError("The input x must be Tensor!")
return x.dtype in mstype.complex_type
def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
"""
Replaces `NaN`, positive infinity, and negative infinity values in the `x` with the values
@ -7669,7 +7790,7 @@ def isinf(input):
def _is_sign_inf(x, fn):
"""Tests element-wise for inifinity with sign."""
"""Tests element-wise for infinity with sign."""
shape = x.shape
zeros_tensor = _get_cache_prim(P.Zeros)()(shape, mstype.float32)
ones_tensor = _get_cache_prim(P.Ones)()(shape, mstype.float32)
@ -7839,6 +7960,8 @@ __all__ = [
'tensor_floordiv',
'floor_div',
'floordiv',
'float_power',
'fmod',
'xdivy',
'tensor_pow',
'pow',
@ -7865,6 +7988,7 @@ __all__ = [
'isreal',
'isneginf',
'isposinf',
'is_complex',
'log',
'logdet',
'log_matrix_determinant',

View File

@ -175,6 +175,8 @@ tensor_operator_registry.register('index_fill', index_fill)
tensor_operator_registry.register('flip', flip)
tensor_operator_registry.register('fliplr', fliplr)
tensor_operator_registry.register('flipud', flipud)
tensor_operator_registry.register('float_power', float_power)
tensor_operator_registry.register('fmod', fmod)
tensor_operator_registry.register('is_floating_point', is_floating_point)
tensor_operator_registry.register('bitwise_and', bitwise_and)
tensor_operator_registry.register('bitwise_or', bitwise_or)
@ -352,6 +354,7 @@ tensor_operator_registry.register('equal', equal)
tensor_operator_registry.register('expm1', expm1)
tensor_operator_registry.register('isinf', isinf)
tensor_operator_registry.register('isnan', isnan)
tensor_operator_registry.register('is_complex', is_complex)
tensor_operator_registry.register('le', le)
tensor_operator_registry.register('less', less)
tensor_operator_registry.register('logical_and', logical_and)

View File

@ -0,0 +1,67 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x, exp):
output = ops.float_power(x, exp)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_float_power_real(mode):
"""
Feature: ops.float_power
Description: Verify the result of float_power
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
input_case = ms.Tensor(np.array([-3., -2, -1, 1, 2, 3]), ms.float32)
output_case = net(input_case, 2)
except_case = np.array([9.0000, 4.0000, 1.0000, 1.0000, 4.0000, 9.0000], dtype=np.float32)
assert output_case.asnumpy().dtype == np.float64
assert np.allclose(output_case.asnumpy(), except_case)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_float_power_complex(mode):
"""
Feature: ops.float_power
Description: Verify the result of float_power
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
input_case = ms.Tensor(np.array([complex(2, 3), complex(3, 4)]), ms.complex64)
output_case = net(input_case, 2)
except_case = np.array([complex(-5, 12), complex(-7, 24)], dtype=np.complex64)
assert np.allclose(output_case.asnumpy(), except_case)

View File

@ -0,0 +1,55 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x, other):
output = ops.fmod(x, other)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fmod(mode):
"""
Feature: ops.fmod
Description: Verify the result of fmod
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
input_case_1 = ms.Tensor(np.array([-3., -2, -1, 1, 2, 3]), ms.float32)
output_case_1 = net(input_case_1, 1.5)
except_case_1 = np.array([-0.0000, -0.5000, -1.0000, 1.0000, 0.5000, 0.0000], dtype=np.float32)
assert np.allclose(output_case_1.asnumpy(), except_case_1)
input_case_2_1 = ms.Tensor(np.array([1, 2, 3, 4, 5]), ms.float32)
input_case_2_2 = ms.Tensor(np.array([-1.5, -0.5, 1.5, 2.5, 3.5]), ms.float32)
output_case_2 = net(input_case_2_1, input_case_2_2)
except_case_2 = np.array([1.0000, 0.0000, 0.0000, 1.5000, 1.5000], dtype=np.float32)
assert np.allclose(output_case_2.asnumpy(), except_case_2)

View File

@ -0,0 +1,53 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x):
output = ops.is_complex(x)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_is_complex(mode):
"""
Feature: ops.is_complex
Description: Verify the result of is_complex
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
a = ms.Tensor([complex(2, 3), complex(1, 3), complex(2.2, 3)], ms.complex64)
b = ms.Tensor(complex(2, 3), ms.complex128)
c = ms.Tensor([1, 2, 3], ms.float32)
out1 = net(a)
out2 = net(b)
out3 = net(c)
assert out1
assert out2
assert not out3

View File

@ -0,0 +1,65 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x, exp):
return x.float_power(exp)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_float_power_real(mode):
"""
Feature: tensor.float_power
Description: Verify the result of float_power
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
input_case = ms.Tensor(np.array([-3., -2, -1, 1, 2, 3]), ms.float32)
output_case = net(input_case, 2)
except_case = np.array([9.0000, 4.0000, 1.0000, 1.0000, 4.0000, 9.0000], dtype=np.float32)
assert output_case.asnumpy().dtype == np.float64
assert np.allclose(output_case.asnumpy(), except_case)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_float_power_complex(mode):
"""
Feature: tensor.float_power
Description: Verify the result of float_power
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
input_case = ms.Tensor(np.array([complex(2, 3), complex(3, 4)]), ms.complex64)
output_case = net(input_case, 2)
except_case = np.array([complex(-5, 12), complex(-7, 24)], dtype=np.complex64)
assert np.allclose(output_case.asnumpy(), except_case)

View File

@ -0,0 +1,53 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x, other):
return x.fmod(other)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fmod(mode):
"""
Feature: tensor.fmod
Description: Verify the result of fmod
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
input_case_1 = ms.Tensor(np.array([-3., -2, -1, 1, 2, 3]), ms.float32)
output_case_1 = net(input_case_1, 1.5)
except_case_1 = np.array([-0.0000, -0.5000, -1.0000, 1.0000, 0.5000, 0.0000], dtype=np.float32)
assert np.allclose(output_case_1.asnumpy(), except_case_1)
input_case_2_1 = ms.Tensor(np.array([1, 2, 3, 4, 5]), ms.float32)
input_case_2_2 = ms.Tensor(np.array([-1.5, -0.5, 1.5, 2.5, 3.5]), ms.float32)
output_case_2 = net(input_case_2_1, input_case_2_2)
except_case_2 = np.array([1.0000, 0.0000, 0.0000, 1.5000, 1.5000], dtype=np.float32)
assert np.allclose(output_case_2.asnumpy(), except_case_2)

View File

@ -0,0 +1,51 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
return x.is_complex()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_is_complex(mode):
"""
Feature: tensor.is_complex
Description: Verify the result of is_complex
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
a = ms.Tensor([complex(2, 3), complex(1, 3), complex(2.2, 3)], ms.complex64)
b = ms.Tensor(complex(2, 3), ms.complex128)
c = ms.Tensor([1, 2, 3], ms.float32)
out1 = net(a)
out2 = net(b)
out3 = net(c)
assert out1
assert out2
assert not out3