!43520 [ST][MS][OPS] greater, greater_equal, igamma, igammac Functional APIs; index_add, greater, greater_equal, igamma, igammac Tensor APIs and STs.

Merge pull request !43520 from alashkari/new-apis-oct-10
This commit is contained in:
i-robot 2022-10-27 09:06:22 +00:00 committed by Gitee
commit b12587f4d8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
19 changed files with 937 additions and 100 deletions

View File

@ -125,6 +125,8 @@ mindspore.ops.function
mindspore.ops.cumprod
mindspore.ops.erfinv
mindspore.ops.less_equal
mindspore.ops.igamma
mindspore.ops.igammac
Element-wise Operations
^^^^^^^^^^^^^^^^^^^^^^^
@ -235,6 +237,8 @@ Reduction Functions
mindspore.ops.approximate_equal
mindspore.ops.equal
mindspore.ops.ge
mindspore.ops.greater
mindspore.ops.greater_equal
mindspore.ops.gt
mindspore.ops.intopk
mindspore.ops.isclose

View File

@ -55,6 +55,8 @@ mindspore.Tensor
mindspore.Tensor.equal
mindspore.Tensor.expm1
mindspore.Tensor.less_equal
mindspore.Tensor.igamma
mindspore.Tensor.igammac
Element-wise Operations
^^^^^^^^^^^^^^^^^^^^^^^
@ -140,6 +142,8 @@ Reduction Methods
mindspore.Tensor.any
mindspore.Tensor.approximate_equal
mindspore.Tensor.ge
mindspore.Tensor.greater
mindspore.Tensor.greater_equal
mindspore.Tensor.gt
mindspore.Tensor.has_init
mindspore.Tensor.isclose
@ -205,6 +209,7 @@ Array Operations
mindspore.Tensor.gather
mindspore.Tensor.gather_elements
mindspore.Tensor.gather_nd
mindspore.Tensor.index_add
mindspore.Tensor.index_fill
mindspore.Tensor.init_data
mindspore.Tensor.inplace_update

View File

@ -60,6 +60,8 @@ Mathematical Methods
mindspore.Tensor.equal
mindspore.Tensor.expm1
mindspore.Tensor.less_equal
mindspore.Tensor.igamma
mindspore.Tensor.igammac
Element-wise Methods
^^^^^^^^^^^^^^^^^^^^
@ -145,6 +147,8 @@ Comparison Methods
mindspore.Tensor.any
mindspore.Tensor.approximate_equal
mindspore.Tensor.ge
mindspore.Tensor.greater
mindspore.Tensor.greater_equal
mindspore.Tensor.gt
mindspore.Tensor.has_init
mindspore.Tensor.isclose
@ -210,6 +214,7 @@ Array Methods
mindspore.Tensor.gather
mindspore.Tensor.gather_elements
mindspore.Tensor.gather_nd
mindspore.Tensor.index_add
mindspore.Tensor.index_fill
mindspore.Tensor.init_data
mindspore.Tensor.inplace_update

View File

@ -126,6 +126,8 @@ Mathematical Functions
mindspore.ops.cumprod
mindspore.ops.erfinv
mindspore.ops.less_equal
mindspore.ops.igamma
mindspore.ops.igammac
Element-by-Element Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -235,6 +237,8 @@ Comparison Functions
mindspore.ops.approximate_equal
mindspore.ops.equal
mindspore.ops.ge
mindspore.ops.greater
mindspore.ops.greater_equal
mindspore.ops.gt
mindspore.ops.intopk
mindspore.ops.isclose

View File

@ -339,6 +339,11 @@ BuiltInTypeMap &GetMethodMap() {
{"equal", std::string("equal")}, // equal()
{"expm1", std::string("expm1")}, // expm1()
{"dim", prim::kPrimRank}, // P.Rank()
{"index_add", std::string("index_add")}, // index_add()
{"greater", std::string("greater")}, // greater()
{"greater_equal", std::string("greater_equal")}, // greater_equal()
{"igamma", std::string("igamma")}, // igamma()
{"igammac", std::string("igammac")}, // igammac()
}},
{kObjectTypeRowTensorType,
{
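The method-map entries above are what let the new names resolve on Tensor objects in compiled graphs. A minimal usage sketch (mine, not part of the patch; it assumes this build exposes the `jit` decorator used elsewhere in this change set):

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.common.api import jit

@jit
def compare(x, y):
    # Both calls resolve through the Tensor method map registered above.
    return x.greater(y), x.greater_equal(y)

x = Tensor(np.array([1, 2, 3]), mindspore.int32)
y = Tensor(np.array([1, 1, 4]), mindspore.int32)
gt, ge = compare(x, y)
print(gt)  # [False  True False]
print(ge)  # [ True  True False]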

View File

@ -26,6 +26,7 @@ from mindspore.ops.composite.base import _append, _insert, _pop, _list_clear, _r
_count, _extend, _dict_clear, _haskey, _update, _fromkeys
from ..._checkparam import Validator as validator
from ..._checkparam import check_is_number
from ...ops import functional as F
from ...ops import operations as P
from ...ops.composite import tail, MultitypeFuncGraph, env_get, hyper_add, \
@ -3451,3 +3452,45 @@ def expm1(input_x):
Computes exponential then minus 1 of a tensor element-wise.
"""
return F.expm1(input_x)
@constexpr
def _check_index_add_alpha(alpha):
check_is_number(alpha, (int, float))
def index_add(input, dim, index, source, *, alpha=1):
r"""
Adds `alpha` times `source` to the input tensor along dimension `dim` at the positions given by `index`.
"""
_check_index_add_alpha(alpha)
source = source * alpha
return F.index_add(input, indices=index, y=source, axis=dim)
def greater(input, other):
r"""
Computes the boolean value of :math:`input > other` element-wise.
"""
return F.greater(input, other)
def greater_equal(input, other):
r"""
Computes the boolean value of :math:`input >= other` element-wise.
"""
return F.greater_equal(input, other)
def igamma(input, other):
r"""
Computes the lower regularized incomplete Gamma function.
"""
return F.igamma(input, other)
def igammac(input, other):
r"""
Computes the upper regularized incomplete Gamma function.
"""
return F.igammac(input, other)
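A quick sketch of the `alpha` semantics implemented above (my example, not from the patch): both the graph-mode helper here and the Tensor method below simply pre-scale `source` by `alpha` before calling F.index_add, so passing `alpha` should be equivalent to scaling `source` yourself. Note that index_add updates the target Parameter in place, hence the two separate Parameters:

import numpy as np
import mindspore
from mindspore import Tensor, Parameter

index = Tensor(np.array([0, 2]), mindspore.int32)
source = Tensor(np.ones((3, 2), dtype=np.float32))
x1 = Parameter(Tensor(np.eye(3, dtype=np.float32)), name="x1")
x2 = Parameter(Tensor(np.eye(3, dtype=np.float32)), name="x2")
out1 = x1.index_add(1, index, source, alpha=2)
out2 = x2.index_add(1, index, source * 2)
np.testing.assert_array_equal(out1.asnumpy(), out2.asnumpy())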

View File

@ -34,7 +34,7 @@ from mindspore._c_expression import COOTensor as COOTensor_
from mindspore._c_expression import CSRTensor as CSRTensor_
from mindspore._c_expression import RowTensor as RowTensor_
from mindspore._c_expression import Tensor as Tensor_
from mindspore._checkparam import Rel
from mindspore._checkparam import Rel, check_is_number
from mindspore._checkparam import Validator as validator
np_types = (np.int8, np.int16, np.int32, np.int64,
@ -6968,6 +6968,200 @@ class Tensor(Tensor_):
return tensor_operator_registry.get('expm1')(self)
def index_add(self, dim, index, source, *, alpha=1):
r"""
Adds `alpha` times `source` to the input tensor along dimension `dim` at the positions given by `index`.
`dim` should be in [0, the rank of the input tensor - 1], and the values of `index` should be in
[0, the size of the input tensor - 1] along the `dim` dimension.
Args:
dim (int): The dimension along which to index.
index (Tensor): The indices along `dim` at which `source` is added to the input tensor, with data
type int32. The `index` must be 1D with the same size as `source` in the `dim` dimension. The values
of `index` should be in [0, b), where b is the size of the input tensor in the `dim` dimension.
source (Tensor): The tensor with the values to add. Must have the same data type as the input tensor.
The shape must be the same as the input tensor's except in the `dim` dimension.
alpha (number.Number): The scalar multiplier for `source`. Default: 1.
Returns:
Tensor, with the same shape and dtype as the input tensor.
Raises:
TypeError: If the input tensor is not a Parameter.
TypeError: If `index` or `source` is not a Tensor.
ValueError: If `dim` is outside the range of the input tensor's rank.
ValueError: If the rank of the input tensor is not the same as the rank of `source`.
ValueError: If `index` is not 1D or its size is not equal to the size of `source` in the `dim` dimension.
ValueError: If the shape of `source` is not the same as the input tensor's except in the `dim` dimension.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter
>>> from mindspore import ops
>>> x = Parameter(Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32), name="name_x")
>>> index = Tensor(np.array([0, 2]), mindspore.int32)
>>> source = Tensor(np.array([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]]), mindspore.float32)
>>> output = x.index_add(dim=1, index=index, source=source)
>>> print(output)
[[1.5 2. 4. ]
[5. 5. 7.5]
[9. 8. 11.5]]
"""
self._init_check()
check_is_number(alpha, (int, float))
source = tensor_operator_registry.get('__mul__')(source, alpha)
return tensor_operator_registry.get('index_add')(self, indices=index, y=source, axis=dim)
def greater(self, other):
r"""
Computes the boolean value of :math:`input > other` element-wise.
Args:
other (Union[Tensor, number.Number, bool]): The second input. When the first input is a Tensor,
the second input should be a number.Number or bool value, or a Tensor whose data type is
number or bool\_. When the first input is a Scalar, the second input must be a Tensor whose
data type is number or bool\_.
Returns:
Tensor, the shape is the same as the one after broadcasting, and the data type is bool.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> y = Tensor(np.array([1, 1, 4]), mindspore.int32)
>>> output = x.greater(y)
>>> print(output)
[False True False]
"""
self._init_check()
return tensor_operator_registry.get('greater')(self, other)
def greater_equal(self, other):
r"""
Computes the boolean value of :math:`input >= other` element-wise.
Args:
other (Union[Tensor, number.Number, bool]): The second input. When the first input is a Tensor,
the second input should be a number.Number or bool value, or a Tensor whose data type is
number or bool\_. When the first input is a Scalar, the second input must be a Tensor whose
data type is number or bool\_.
Returns:
Tensor, the shape is the same as the one after broadcasting, and the data type is bool.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> y = Tensor(np.array([1, 1, 4]), mindspore.int32)
>>> output = x.greater_equal(y)
>>> print(output)
[True True False]
"""
self._init_check()
return tensor_operator_registry.get('greater_equal')(self, other)
def igamma(self, other):
r"""
Calculates the lower regularized incomplete Gamma function.
If we define the input tensor as `a` and `other` as `x`, the lower regularized incomplete Gamma function
is defined as:
.. math::
P(a, x) = \gamma(a, x) / \Gamma(a) = 1 - Q(a, x)
where
.. math::
\gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt
is the lower incomplete Gamma function.
Above, :math:`Q(a, x)` is the upper regularized incomplete Gamma function.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Args:
other (Tensor): The second input tensor, with float32 or float64 data type. `other` should have
the same dtype as the input tensor.
Returns:
Tensor, with the same dtype as the input tensor and `other`.
Raises:
TypeError: If `other` is not a Tensor.
TypeError: If the dtype of the input tensor and `other` is neither float32 nor float64.
TypeError: If `other` has a different dtype than the input tensor.
ValueError: If the input tensor cannot be broadcast to a tensor with the shape of `other`.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
>>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
>>> output = a.igamma(x)
>>> print(output)
[0.593994 0.35276785 0.21486944 0.13337152]
"""
self._init_check()
return tensor_operator_registry.get('igamma')(self, other)
def igammac(self, other):
r"""
Calculates the upper regularized incomplete Gamma function.
If we define the input tensor as `a` and `other` as `x`, the upper regularized incomplete Gamma function
is defined as:
.. math::
Q(a, x) = \Gamma(a, x) / \Gamma(a) = 1 - P(a, x)
where
.. math::
\Gamma(a, x) = \int_{x}^{\infty} t^{a-1} e^{-t} dt
is the upper incomplete Gamma function.
Above, :math:`P(a, x)` (igamma) is the lower regularized incomplete Gamma function.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Args:
other (Tensor): The second input tensor, with float32 or float64 data type. `other` should have
the same dtype as the input tensor.
Returns:
Tensor, with the same dtype as the input tensor and `other`.
Raises:
TypeError: If `other` is not a Tensor.
TypeError: If the dtype of the input tensor and `other` is neither float32 nor float64.
TypeError: If `other` has a different dtype than the input tensor.
ValueError: If the input tensor cannot be broadcast to a tensor with the shape of `other`.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
>>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
>>> output = a.igammac(x)
>>> print(output)
[0.40600586 0.6472318 0.7851304 0.8666283]
"""
self._init_check()
return tensor_operator_registry.get('igammac')(self, other)
class RowTensor(RowTensor_):
"""
A sparse representation of a set of tensor slices at given indices.
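Since the two new special functions are complementary, i.e. P(a, x) + Q(a, x) = 1, a simple PyNative sanity check of the Tensor APIs above (my sketch, reusing the docstring inputs) is:

import numpy as np
from mindspore import Tensor

a = Tensor(np.array([2.0, 4.0, 6.0, 8.0], np.float32))
x = Tensor(np.array([2.0, 3.0, 4.0, 5.0], np.float32))
# The regularized lower and upper incomplete Gamma functions sum to 1.
total = a.igamma(x) + a.igammac(x)
np.testing.assert_allclose(total.asnumpy(), np.ones(4, np.float32), rtol=1e-5)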

View File

@ -285,6 +285,10 @@ from .math_func import (
erfinv,
less_equal,
cumprod,
greater,
greater_equal,
igamma,
igammac,
)
from .nn_func import (
adaptive_avg_pool1d,

View File

@ -56,6 +56,8 @@ from mindspore.ops.operations.math_ops import (
Sinc,
SparseSegmentMean,
InplaceUpdateV2,
Igamma,
Igammac,
)
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
@ -6578,6 +6580,157 @@ def cumprod(input, dim, dtype=None):
return output
def greater(input, other):
r"""
Computes the boolean value of :math:`input > other` element-wise.
Args:
input (Union[Tensor, number.Number, bool]): The first input is a number.Number or
a bool or a tensor whose data type is
`number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
`bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
other (Union[Tensor, number.Number, bool]): The second input. When the first input is a Tensor,
the second input should be a number.Number or bool value, or a Tensor whose data type is number or bool\_.
When the first input is a Scalar, the second input must be a Tensor whose data type is number or bool\_.
Returns:
Tensor, the shape is the same as the one after broadcasting, and the data type is bool.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> y = Tensor(np.array([1, 1, 4]), mindspore.int32)
>>> output = ops.greater(x, y)
>>> print(output)
[False True False]
"""
greater_op = _get_cache_prim(P.Greater)()
return greater_op(input, other)
def greater_equal(input, other):
r"""
Computes the boolean value of :math:`input >= other` element-wise.
Args:
input (Union[Tensor, number.Number, bool]): The first input is a number.Number or
a bool or a tensor whose data type is
`number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
`bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
other (Union[Tensor, number.Number, bool]): The second input. When the first input is a Tensor,
the second input should be a number.Number or bool value, or a Tensor whose data type is number or bool\_.
When the first input is a Scalar, the second input must be a Tensor whose data type is number or bool\_.
Returns:
Tensor, the shape is the same as the one after broadcasting, and the data type is bool.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> x = Tensor(np.array([1, 2, 3]), mindspore.int32)
>>> y = Tensor(np.array([1, 1, 4]), mindspore.int32)
>>> output = ops.greater_equal(x, y)
>>> print(output)
[True True False]
"""
greater_equal_op = _get_cache_prim(P.GreaterEqual)()
return greater_equal_op(input, other)
def igamma(input, other):
r"""
Calculates the lower regularized incomplete Gamma function.
If we define `input` as `a` and `other` as `x`, the lower regularized incomplete Gamma function is defined as:
.. math::
P(a, x) = \gamma(a, x) / \Gamma(a) = 1 - Q(a, x)
where
.. math::
\gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt
is the lower incomplete Gamma function.
Above, :math:`Q(a, x)` is the upper regularized incomplete Gamma function.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Args:
input (Tensor): The first input tensor, with float32 or float64 data type.
other (Tensor): The second input tensor, with float32 or float64 data type. `other` should have
the same dtype as `input`.
Returns:
Tensor, with the same dtype as `input` and `other`.
Raises:
TypeError: If `input` or `other` is not a Tensor.
TypeError: If the dtype of `input` and `other` is neither float32 nor float64.
TypeError: If `other` has a different dtype than `input`.
ValueError: If `input` cannot be broadcast to a tensor with the shape of `other`.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
>>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
>>> output = ops.igamma(a, x)
>>> print(output)
[0.593994 0.35276785 0.21486944 0.13337152]
"""
igamma_op = _get_cache_prim(Igamma)()
return igamma_op(input, other)
def igammac(input, other):
r"""
Calculates the upper regularized incomplete Gamma function.
If we define `input` as `a` and `other` as `x`, the upper regularized incomplete Gamma function is defined as:
.. math::
Q(a, x) = \Gamma(a, x) / \Gamma(a) = 1 - P(a, x)
where
.. math::
\Gamma(a, x) = \int_{x}^{\infty} t^{a-1} e^{-t} dt
is the upper incomplete Gamma function.
Above, :math:`P(a, x)` (igamma) is the lower regularized incomplete Gamma function.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Args:
input (Tensor): The first input tensor, with float32 or float64 data type.
other (Tensor): The second input tensor, with float32 or float64 data type. `other` should have
the same dtype as `input`.
Returns:
Tensor, with the same dtype as `input` and `other`.
Raises:
TypeError: If `input` or `other` is not a Tensor.
TypeError: If the dtype of `input` and `other` is neither float32 nor float64.
TypeError: If `other` has a different dtype than `input`.
ValueError: If `input` cannot be broadcast to a tensor with the shape of `other`.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
>>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
>>> output = ops.igammac(a, x)
>>> print(output)
[0.40600586 0.6472318 0.7851304 0.8666283]
"""
igammac_op = _get_cache_prim(Igammac)()
return igammac_op(input, other)
__all__ = [
'addn',
'absolute',
@ -6734,5 +6887,9 @@ __all__ = [
'erfinv',
'less_equal',
'cumprod',
'greater',
'greater_equal',
'igamma',
'igammac',
]
__all__.sort()
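The expected values in the igamma/igammac docstrings can be cross-checked against SciPy, which uses the same regularized P(a, x) / Q(a, x) convention (a sketch; SciPy is an assumption here, not a dependency of this patch):

import numpy as np
from scipy import special

a = np.array([2.0, 4.0, 6.0, 8.0], np.float32)
x = np.array([2.0, 3.0, 4.0, 5.0], np.float32)
print(special.gammainc(a, x))   # ~[0.593994 0.352768 0.214869 0.133372]
print(special.gammaincc(a, x))  # ~[0.406006 0.647232 0.785130 0.866628]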

View File

@ -422,6 +422,11 @@ tensor_operator_registry.register('erfinv', erfinv)
tensor_operator_registry.register('less_equal', less_equal)
tensor_operator_registry.register('fold', fold)
tensor_operator_registry.register('unfold', unfold)
tensor_operator_registry.register('index_add', index_add)
tensor_operator_registry.register('greater', greater)
tensor_operator_registry.register('greater_equal', greater_equal)
tensor_operator_registry.register('igamma', igamma)
tensor_operator_registry.register('igammac', igammac)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
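These registrations are how the PyNative Tensor methods in tensor.py find their functional counterparts: each method fetches the callable by name and applies it with `self` as the first argument. A self-contained sketch of the pattern (a stand-in registry, not MindSpore's actual class):

class _Registry(dict):
    """Stand-in for tensor_operator_registry: a name -> callable mapping."""
    def register(self, name, fn):
        self[name] = fn

_registry = _Registry()
_registry.register('greater', lambda a, b: a > b)  # stand-in for ops.greater

class _FakeTensor:
    def __init__(self, value):
        self.value = value
    def greater(self, other):
        # Mirrors Tensor.greater above: look up by name, pass self first.
        return _registry['greater'](self.value, other.value)

print(_FakeTensor(3).greater(_FakeTensor(2)))  # True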

View File

@ -0,0 +1,102 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import Tensor
import mindspore.context as context
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_greater_equal_functional_api_modes(mode):
"""
Feature: Test greater_equal functional api.
Description: Test greater_equal functional api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="Ascend")
x = Tensor(np.array([1, 2, 3]), mstype.int32)
y = Tensor(np.array([1, 1, 4]), mstype.int32)
output = F.greater_equal(x, y)
expected = np.array([True, True, False])
np.testing.assert_array_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_greater_equal_tensor_api_modes(mode):
"""
Feature: Test greater_equal tensor api.
Description: Test greater_equal tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="Ascend")
x = Tensor(np.array([1, 2, 3]), mstype.int32)
y = Tensor(np.array([1, 1, 4]), mstype.int32)
output = x.greater_equal(y)
expected = np.array([True, True, False])
np.testing.assert_array_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_greater_functional_api_modes(mode):
"""
Feature: Test greater functional api.
Description: Test greater functional api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="Ascend")
x = Tensor(np.array([1, 2, 3]), mstype.int32)
y = Tensor(np.array([1, 1, 4]), mstype.int32)
output = F.greater(x, y)
expected = np.array([False, True, False])
np.testing.assert_array_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_greater_tensor_api_modes(mode):
"""
Feature: Test greater tensor api.
Description: Test greater tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="Ascend")
x = Tensor(np.array([1, 2, 3]), mstype.int32)
y = Tensor(np.array([1, 1, 4]), mstype.int32)
output = x.greater(y)
expected = np.array([False, True, False])
np.testing.assert_array_equal(output.asnumpy(), expected)

View File

@ -21,6 +21,7 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore.common import dtype as mstype
class NetIndexAdd(nn.Cell):
@ -291,3 +292,24 @@ def test_index_add_dynamic_indices():
net.set_inputs(idx_dyn, Tensor(y))
output = net(Tensor(idx), Tensor(y))
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_index_add_tensor_api_modes(mode):
"""
Feature: Test index_add tensor api.
Description: Test index_add tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="Ascend")
x = Parameter(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float32), name="name_x")
index = Tensor([0, 2], mstype.int32)
source = Tensor([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]], mstype.float32)
dim = 1
output = x.index_add(dim, index, source)
expected = np.array([[1.5, 2., 4.], [5., 5., 7.5], [9., 8., 11.5]], np.float32)
np.testing.assert_array_equal(output.asnumpy(), expected)

View File

@ -18,6 +18,8 @@ import numpy as np
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@ -80,3 +82,79 @@ def test_greater_greaterequal_ops_infer_value_shape():
assert greater_out_shape == (2, 3)
assert not greater_equal_out
assert greater_equal_out_shape == (2, 3)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_greater_equal_functional_api_modes(mode):
"""
Feature: Test greater_equal functional api.
Description: Test greater_equal functional api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="CPU")
x = Tensor(np.array([1, 2, 3]), mstype.int32)
y = Tensor(np.array([1, 1, 4]), mstype.int32)
output = F.greater_equal(x, y)
expected = np.array([True, True, False])
np.testing.assert_array_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_greater_equal_tensor_api_modes(mode):
"""
Feature: Test greater_equal tensor api.
Description: Test greater_equal tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="CPU")
x = Tensor(np.array([1, 2, 3]), mstype.int32)
y = Tensor(np.array([1, 1, 4]), mstype.int32)
output = x.greater_equal(y)
expected = np.array([True, True, False])
np.testing.assert_array_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_greater_functional_api_modes(mode):
"""
Feature: Test greater functional api.
Description: Test greater functional api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="CPU")
x = Tensor(np.array([1, 2, 3]), mstype.int32)
y = Tensor(np.array([1, 1, 4]), mstype.int32)
output = F.greater(x, y)
expected = np.array([False, True, False])
np.testing.assert_array_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_greater_tensor_api_modes(mode):
"""
Feature: Test greater tensor api.
Description: Test greater tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="CPU")
x = Tensor(np.array([1, 2, 3]), mstype.int32)
y = Tensor(np.array([1, 1, 4]), mstype.int32)
output = x.greater(y)
expected = np.array([False, True, False])
np.testing.assert_array_equal(output.asnumpy(), expected)

View File

@ -0,0 +1,94 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import Tensor
import mindspore.context as context
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_igamma_functional_api_modes(mode):
"""
Feature: Test igamma functional api.
Description: Test igamma functional api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="CPU")
a = Tensor([2.0, 4.0, 6.0, 8.0], mstype.float32)
x = Tensor([2.0, 3.0, 4.0, 5.0], mstype.float32)
output = F.igamma(a, x)
expected = np.array([0.593994, 0.35276785, 0.21486944, 0.13337152])
np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=4)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_igamma_tensor_api_modes(mode):
"""
Feature: Test igamma tensor api.
Description: Test igamma tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="CPU")
a = Tensor([2.0, 4.0, 6.0, 8.0], mstype.float32)
x = Tensor([2.0, 3.0, 4.0, 5.0], mstype.float32)
output = a.igamma(x)
expected = np.array([0.593994, 0.35276785, 0.21486944, 0.13337152])
np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=4)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_igammac_functional_api_modes(mode):
"""
Feature: Test igammac functional api.
Description: Test igammac functional api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="CPU")
a = Tensor([2.0, 4.0, 6.0, 8.0], mstype.float32)
x = Tensor([2.0, 3.0, 4.0, 5.0], mstype.float32)
output = F.igammac(a, x)
expected = np.array([0.40600586, 0.6472318, 0.7851304, 0.8666283])
np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=4)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_igammac_tensor_api_modes(mode):
"""
Feature: Test igammac tensor api.
Description: Test igammac tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="CPU")
a = Tensor([2.0, 4.0, 6.0, 8.0], mstype.float32)
x = Tensor([2.0, 3.0, 4.0, 5.0], mstype.float32)
output = a.igammac(x)
expected = np.array([0.40600586, 0.6472318, 0.7851304, 0.8666283])
np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=4)

View File

@ -22,6 +22,7 @@ import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore.ops.functional import vmap
from mindspore.common import dtype as mstype
class NetIndexAdd(nn.Cell):
@ -398,3 +399,23 @@ def test_index_add_vmap_cpu():
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
vmap_case()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_index_add_tensor_api_modes(mode):
"""
Feature: Test index_add tensor api.
Description: Test index_add tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="CPU")
x = Parameter(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float32), name="name_x")
index = Tensor([0, 2], mstype.int32)
source = Tensor([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]], mstype.float32)
dim = 1
output = x.index_add(dim, index, source)
expected = np.array([[1.5, 2., 4.], [5., 5., 7.5], [9., 8., 11.5]], np.float32)
np.testing.assert_array_equal(output.asnumpy(), expected)

View File

@ -1,99 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import torch
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.other_ops as P
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.api import jit
class BartlettWindowNet(nn.Cell):
def __init__(self, periodic=True, dtype=mstype.float32):
super(BartlettWindowNet, self).__init__()
self.bartlettwindow = P.BartlettWindow(periodic=periodic, dtype=dtype)
@jit
def construct(self, input_x):
return self.bartlettwindow(input_x)
def get_dtype(dtype="float16"):
if dtype == "float16":
nptype = np.float16
msptype = mstype.float16
pttype = torch.float32
elif dtype == "float32":
nptype = np.float32
msptype = mstype.float32
pttype = torch.float32
elif dtype == "float64":
nptype = np.float64
msptype = mstype.float64
pttype = torch.float64
else:
print("The attr 'dtype' must in [float16, float32, float64]")
return nptype, msptype, pttype
def bartlett_window(periodic, dtype, loss):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
nptype, msptype, pttype = get_dtype(dtype)
input_x_np = np.array(200, dtype=np.int32)
input_x_ms = Tensor(input_x_np)
input_x_torch = torch.tensor(input_x_np)
bartlett_window_net = BartlettWindowNet(periodic, msptype)
bartlett_window_output = bartlett_window_net(input_x_ms)
bartlett_window_expect = torch.bartlett_window(input_x_torch, periodic=periodic, dtype=pttype)
assert np.allclose(bartlett_window_output.asnumpy(), bartlett_window_expect.numpy().astype(nptype), loss, loss)
def bartlett_window_pynative(periodic, dtype, loss):
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
nptype, msptype, pttype = get_dtype(dtype)
input_x_np = np.array(200, dtype=np.int64)
input_x_ms = Tensor(input_x_np)
input_x_torch = torch.tensor(input_x_np)
bartlett_window_net = BartlettWindowNet(periodic, msptype)
bartlett_window_output = bartlett_window_net(input_x_ms)
bartlett_window_expect = torch.bartlett_window(input_x_torch, periodic=periodic, dtype=pttype)
assert np.allclose(bartlett_window_output.asnumpy(), bartlett_window_expect.numpy().astype(nptype), loss, loss)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bartlett_window_graph_int32_true_float32():
"""
Feature: ALL To ALL
Description: test cases for BartlettWindow
Expectation: the result match to torch
"""
bartlett_window(periodic=True, dtype="float32", loss=1.0e-4)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bartlett_window_pynative_int64_false_float64():
"""
Feature: ALL To ALL
Description: test cases for BartlettWindow
Expectation: the result match to torch
"""
bartlett_window_pynative(periodic=False, dtype="float64", loss=1.0e-5)

View File

@ -0,0 +1,98 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import Tensor
import mindspore.context as context
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_greater_equal_functional_api_modes(mode):
"""
Feature: Test greater_equal functional api.
Description: Test greater_equal functional api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="GPU")
x = Tensor(np.array([1, 2, 3]), mstype.int32)
y = Tensor(np.array([1, 1, 4]), mstype.int32)
output = F.greater_equal(x, y)
expected = np.array([True, True, False])
np.testing.assert_array_equal(output.asnumpy(), expected)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_greater_equal_tensor_api_modes(mode):
"""
Feature: Test greater_equal tensor api.
Description: Test greater_equal tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="GPU")
x = Tensor(np.array([1, 2, 3]), mstype.int32)
y = Tensor(np.array([1, 1, 4]), mstype.int32)
output = x.greater_equal(y)
expected = np.array([True, True, False])
np.testing.assert_array_equal(output.asnumpy(), expected)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_greater_functional_api_modes(mode):
"""
Feature: Test greater functional api.
Description: Test greater functional api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="GPU")
x = Tensor(np.array([1, 2, 3]), mstype.int32)
y = Tensor(np.array([1, 1, 4]), mstype.int32)
output = F.greater(x, y)
expected = np.array([False, True, False])
np.testing.assert_array_equal(output.asnumpy(), expected)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_greater_tensor_api_modes(mode):
"""
Feature: Test greater tensor api.
Description: Test greater tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="GPU")
x = Tensor(np.array([1, 2, 3]), mstype.int32)
y = Tensor(np.array([1, 1, 4]), mstype.int32)
output = x.greater(y)
expected = np.array([False, True, False])
np.testing.assert_array_equal(output.asnumpy(), expected)

View File

@ -21,6 +21,8 @@ from mindspore.ops import composite as C
from mindspore.ops.operations.math_ops import Igamma, Igammac
from mindspore.ops.operations._grad_ops import IgammaGradA
from mindspore.nn import Cell
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
class IgammaNet(Cell):
@ -197,3 +199,75 @@ def test_igammagrada_fp64():
output_ms = net(Tensor(a_np), Tensor(x_np))
expect_output = np.array([[0, 0], [0, 0]])
assert np.allclose(output_ms.asnumpy(), expect_output, 1e-5, 1e-5)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_igamma_functional_api_modes(mode):
"""
Feature: Test igamma functional api.
Description: Test igamma functional api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="GPU")
a = Tensor([2.0, 4.0, 6.0, 8.0], mstype.float32)
x = Tensor([2.0, 3.0, 4.0, 5.0], mstype.float32)
output = F.igamma(a, x)
expected = np.array([0.593994, 0.35276785, 0.21486944, 0.13337152])
np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=4)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_igamma_tensor_api_modes(mode):
"""
Feature: Test igamma tensor api.
Description: Test igamma tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="GPU")
a = Tensor([2.0, 4.0, 6.0, 8.0], mstype.float32)
x = Tensor([2.0, 3.0, 4.0, 5.0], mstype.float32)
output = a.igamma(x)
expected = np.array([0.593994, 0.35276785, 0.21486944, 0.13337152])
np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=4)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_igammac_functional_api_modes(mode):
"""
Feature: Test igammac functional api.
Description: Test igammac functional api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="GPU")
a = Tensor([2.0, 4.0, 6.0, 8.0], mstype.float32)
x = Tensor([2.0, 3.0, 4.0, 5.0], mstype.float32)
output = F.igammac(a, x)
expected = np.array([0.40600586, 0.6472318, 0.7851304, 0.8666283])
np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=4)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_igammac_tensor_api_modes(mode):
"""
Feature: Test igammac tensor api.
Description: Test igammac tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="GPU")
a = Tensor([2.0, 4.0, 6.0, 8.0], mstype.float32)
x = Tensor([2.0, 3.0, 4.0, 5.0], mstype.float32)
output = a.igammac(x)
expected = np.array([0.40600586, 0.6472318, 0.7851304, 0.8666283])
np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=4)

View File

@ -22,6 +22,7 @@ import mindspore.nn as nn
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
class NetIndexAdd(nn.Cell):
@ -419,3 +420,23 @@ def test_index_add_dynamic():
net.set_inputs(Tensor(idx), y_dyn)
output = net(Tensor(idx), Tensor(y))
assert (output.asnumpy() == expect).all()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_index_add_tensor_api_modes(mode):
"""
Feature: Test index_add tensor api.
Description: Test index_add tensor api for Graph and PyNative modes.
Expectation: The result matches the expected value.
"""
context.set_context(mode=mode, device_target="GPU")
x = Parameter(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float32), name="name_x")
index = Tensor([0, 2], mstype.int32)
source = Tensor([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]], mstype.float32)
dim = 1
output = x.index_add(dim, index, source)
expected = np.array([[1.5, 2., 4.], [5., 5., 7.5], [9., 8., 11.5]], np.float32)
np.testing.assert_array_equal(output.asnumpy(), expected)