!41473 [ST][MS][OPS] bmm Functional API & ST - addmv, asinh, atan, atanh & bmm Tensor APIs and STs.

Merge pull request !41473 from alashkari/tensor-apis-sept04
This commit is contained in:
i-robot 2022-09-13 03:12:28 +00:00 committed by Gitee
commit 1f3b82199f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
25 changed files with 991 additions and 23 deletions

View File

@ -125,7 +125,14 @@ functional operators are initialized Primitives that can be called directly as functions
mindspore.ops.cdist
Mathematical Functions
----------------------
^^^^^^^^^^^^^^^^^^^^^^
.. mscnplatformautosummary::
:toctree: ops
:nosignatures:
:template: classtemplate.rst
mindspore.ops.bmm
Element-wise Operations
^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -40,7 +40,13 @@ mindspore.Tensor
mindspore.Tensor.soft_shrink
Mathematical Methods
--------------------
^^^^^^^^^^^^^^^^^^^^
.. mscnplatformautosummary::
:toctree: Tensor
:nosignatures:
mindspore.Tensor.bmm
Element-wise Operations
^^^^^^^^^^^^^^^^^^^^^^^
@ -55,6 +61,10 @@ mindspore.Tensor
mindspore.Tensor.addcdiv
mindspore.Tensor.addcmul
mindspore.Tensor.asin
mindspore.Tensor.addmv
mindspore.Tensor.asinh
mindspore.Tensor.atan
mindspore.Tensor.atanh
mindspore.Tensor.atan2
mindspore.Tensor.bernoulli
mindspore.Tensor.bitwise_and

View File

@ -38,7 +38,13 @@ Activation Function
mindspore.Tensor.soft_shrink
Mathematical Methods
--------------------
^^^^^^^^^^^^^^^^^^^^
.. msplatformautosummary::
:toctree: Tensor
:nosignatures:
mindspore.Tensor.bmm
Element-wise Methods
^^^^^^^^^^^^^^^^^^^^
@ -53,6 +59,10 @@ Element-wise Methods
mindspore.Tensor.addcmul
mindspore.Tensor.addr
mindspore.Tensor.asin
mindspore.Tensor.addmv
mindspore.Tensor.asinh
mindspore.Tensor.atan
mindspore.Tensor.atanh
mindspore.Tensor.atan2
mindspore.Tensor.bernoulli
mindspore.Tensor.bitwise_and

View File

@ -127,7 +127,14 @@ Distance Functions
mindspore.ops.cdist
Mathematical Functions
----------------------
^^^^^^^^^^^^^^^^^^^^^^
.. msplatformautosummary::
:toctree: ops
:nosignatures:
:template: classtemplate.rst
mindspore.ops.bmm
Element-by-Element Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -297,6 +297,11 @@ BuiltInTypeMap &GetMethodMap() {
{"addr", std::string("addr")}, // addr()
{"add", std::string("add")}, // P.Add()
{"asin", std::string("asin")}, // asin()
{"addmv", std::string("addmv")}, // addmv()
{"asinh", std::string("asinh")}, // asinh()
{"atan", std::string("atan")}, // atan()
{"atanh", std::string("atanh")}, // atanh()
{"bmm", std::string("bmm")}, // bmm()
}},
{kObjectTypeRowTensorType,
{

View File

@ -3076,3 +3076,38 @@ def addr(x, vec1, vec2, beta=1, alpha=1):
Computes the outer-product of `vec1` and `vec2` and adds it to the matrix `x`.
"""
return F.addr(x, vec1, vec2, beta=beta, alpha=alpha)
def addmv(x, mat, vec, beta=1, alpha=1):
r"""
Multiplies matrix `mat` and vector `vec`. The vector `x` is added to the final result.
"""
return F.addmv(x, mat, vec, beta, alpha)
def asinh(x):
r"""
Computes inverse hyperbolic sine of the input element-wise.
"""
return F.asinh(x)
def atan(x):
r"""
Computes inverse tangent of the input element-wise.
"""
return F.atan(x)
def atanh(x):
r"""
Computes inverse hyperbolic tangent of the input element-wise.
"""
return F.atanh(x)
def bmm(input_x, mat2):
r"""
Computes matrix multiplication between two tensors by batch.
"""
return F.bmm(input_x, mat2)
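# For intuition, a minimal NumPy sketch of what the two matrix wrappers above compute
# (sample values reused from the examples in this change set; beta/alpha left at their defaults):
import numpy as np

x = np.array([2., 3.], dtype=np.float32)
mat = np.array([[2., 5., 3.], [4., 2., 2.]], dtype=np.float32)
vec = np.array([3., 2., 4.], dtype=np.float32)
# addmv: beta * x + alpha * (mat @ vec) -> [30. 27.] with the defaults
print(1 * x + 1 * (mat @ vec))

a = np.ones((2, 4, 1, 3), dtype=np.float32)
b = np.ones((2, 4, 3, 4), dtype=np.float32)
# bmm: batched matrix multiplication over the leading batch dimensions;
# equivalent to numpy.matmul here -> shape (2, 4, 1, 4), every entry 3.
print(np.matmul(a, b).shape)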

View File

@ -695,25 +695,25 @@ class Tensor(Tensor_):
r"""
Executes the outer-product of `vec1` and `vec2` and adds it to the input tensor.
If `vec1` is a vector of size :vec1:`N` and `vec2` is a vector of size :vec1:`M`, then `x` must be
broadcastable with a vec1rix of size :vec1:`(N, M)` and `out` will be a vec1rix of size :vec1:`(N, M)`.
If `vec1` is a vector of size `N` and `vec2` is a vector of size `M`, then `x` must be
broadcastable with a matrix of size `(N, M)` and `out` will be a matrix of size `(N, M)`.
The optional values `beta` and `alpha` are the scale factors on the outer product between `vec1` and `vec2`
and the added vec1rix `x` respectively. If `beta` is 0, then `x` will be ignored.
and the added matrix `x` respectively. If `beta` is 0, then `x` will be ignored.
.. math::
output = β x + α (vec1 vec2)
Args:
vec1 (Tensor): The first tensor to be multiplied. The shape of the tensor is :vec1:`(N,)`.
vec2 (Tensor): The second tensor to be multiplied. The shape of the tensor is :vec1:`(M,)`.
vec1 (Tensor): The first tensor to be multiplied. The shape of the tensor is `(N,)`.
vec2 (Tensor): The second tensor to be multiplied. The shape of the tensor is `(M,)`.
beta (scalar[int, float, bool], optional): Multiplier for `x` (β). The `beta` must be int or
float or bool, Default: 1.
alpha (scalar[int, float, bool], optional): Multiplier for `vec1` @ `vec2` (α). The `alpha` must
be int or float or bool, Default: 1.
Outputs:
Tensor, the shape of the output tensor is :vec1:`(N, M)`, has the same dtype as `x`.
Tensor, the shape of the output tensor is `(N, M)`, has the same dtype as `x`.
Raises:
TypeError: If `x`, `vec1`, `vec2` is not a Tensor.
@ -5346,6 +5346,169 @@ class Tensor(Tensor_):
validator.check_axis_in_range(axis, self.ndim)
return tensor_operator_registry.get('median')(global_median, axis, keep_dims)(self)
def addmv(self, mat, vec, beta=1, alpha=1):
r"""
Multiplies matrix `mat` and vector `vec`. Input vector is added to the final result.
If `mat` is a :math:`(N, M)` tensor and `vec` is a 1-D tensor of size :math:`M`, then `x` must be broadcastable
with a 1-D tensor of size :math:`N` and `out` will be a 1-D tensor of size :math:`N`.
The optional values `alpha` and `beta` are the scale factors on the matrix-vector product between `mat` and
`vec` and on the added tensor `x` respectively. If `beta` is 0, then `x` will be ignored.
.. math::
output = β x + α (mat @ vec)
Args:
mat (Tensor): The first tensor to be multiplied. The shape of the tensor is :math:`(N, M)`.
vec (Tensor): The second tensor to be multiplied. The shape of the tensor is :math:`(M,)`.
beta (scalar[int, float, bool], optional): Multiplier for `x` (β). The `beta` must be int,
float or bool. Default: 1.
alpha (scalar[int, float, bool], optional): Multiplier for `mat` @ `vec` (α). The `alpha` must
be int, float or bool. Default: 1.
Returns:
Tensor, the shape of the output tensor is :math:`(N,)`, has the same dtype as `x`.
Raises:
TypeError: If `mat`, `vec` or `x` is not a Tensor.
TypeError: If `x`, `mat` and `vec` do not have the same dtype.
ValueError: If `mat` is not a 2-D Tensor.
ValueError: If `x` or `vec` is not a 1-D Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([2., 3.]).astype(np.float32))
>>> mat = Tensor(np.array([[2., 5., 3.], [4., 2., 2.]]).astype(np.float32))
>>> vec = Tensor(np.array([3., 2., 4.]).astype(np.float32))
>>> output = x.addmv(mat, vec)
>>> print(output)
[30. 27.]
"""
self._init_check()
return tensor_operator_registry.get('addmv')(self, mat, vec, beta=beta, alpha=alpha)
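# Scaling sketch (values reused from the docstring example above): the result follows
# output = beta * x + alpha * (mat @ vec), and beta=0 drops `x`. With mat @ vec = [28., 24.],
# x.addmv(mat, vec, beta=0, alpha=2) should therefore print:
#     [56. 48.]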
def asinh(self):
r"""
Computes inverse hyperbolic sine of the input element-wise.
.. math::
out_i = \sinh^{-1}(input_i)
Returns:
Tensor, has the same shape and type as input.
Raises:
TypeError: If input is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]), mindspore.float32)
>>> output = x.asinh()
>>> print(output)
[-2.3124382 1.1947632 1.8184465 5.298342 ]
"""
self._init_check()
return tensor_operator_registry.get('asinh')(self)
def atan(self):
r"""
Computes the trigonometric inverse tangent of the input element-wise.
.. math::
out_i = \tan^{-1}(x_i)
Returns:
A Tensor, has the same type as the input.
Raises:
TypeError: If input is not a Tensor.
TypeError: If input tensor dtype is not float16 or float32.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([1.0, 0.0]), mindspore.float32)
>>> output = x.atan()
>>> print(output)
[0.7853982 0.0]
"""
self._init_check()
return tensor_operator_registry.get('atan')(self)
def atanh(self):
r"""
Computes inverse hyperbolic tangent of the input element-wise.
.. math::
out_i = \tanh^{-1}(x_{i})
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Returns:
A Tensor, has the same type as the input.
Raises:
TypeError: If input is not a Tensor.
TypeError: If input tensor dtype is not float16 or float32.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([0, -0.5]), mindspore.float32)
>>> output = x.atanh()
>>> print(output)
[ 0. -0.54930615]
"""
self._init_check()
return tensor_operator_registry.get('atanh')(self)
def bmm(self, mat2):
r"""
Computes matrix multiplication between two tensors by batch.
.. math::
\text{output}[..., :, :] = \text{matrix}(input_x[..., :, :]) * \text{matrix}(mat2[..., :, :])
The number of dimensions of `input_x` must not be less than `3` and that of `mat2` must not be less than `2`.
Args:
mat2 (Tensor): The tensor to be multiplied. The shape of the tensor is :math:`(*B, C, M)`.
Returns:
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
Raises:
ValueError: If length of shape of `input_x` is not equal to length of shape of `mat2` or
length of shape of `input_x` is less than 3.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
>>> mat2 = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
>>> output = input_x.bmm(mat2)
>>> print(output)
[[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]
[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]]
"""
self._init_check()
return tensor_operator_registry.get('bmm')(self, mat2)
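# Shape rule in brief (a sketch; `*B` denotes the shared batch dimensions): an input of shape
# (*B, N, C) multiplied with a `mat2` of shape (*B, C, M) yields (*B, N, M), matching
# numpy.matmul on the same arrays:
#     >>> np.matmul(np.ones((2, 4, 1, 3)), np.ones((2, 4, 3, 4))).shape
#     (2, 4, 1, 4)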
class RowTensor(RowTensor_):
"""

View File

@ -266,7 +266,8 @@ from .math_func import (
kron,
rot90,
remainder,
iou
iou,
bmm
)
from .nn_func import (
adaptive_avg_pool1d,

View File

@ -5179,6 +5179,53 @@ def matmul(x1, x2):
return reshape_op(res, shape_out)
def bmm(input_x, mat2):
"""
Computes matrix multiplication between two tensors by batch.
.. math::
\text{output}[..., :, :] = \text{matrix}(input_x[..., :, :]) * \text{matrix}(mat2[..., :, :])
The number of dimensions of `input_x` must not be less than `3` and that of `mat2` must not be less than `2`.
Args:
input_x (Tensor): The first tensor to be multiplied. The shape of the tensor is :math:`(*B, N, C)`,
where :math:`*B` represents the batch size which can be multidimensional, :math:`N` and :math:`C` are the
size of the last two dimensions.
mat2 (Tensor): The second tensor to be multiplied. The shape of the tensor is :math:`(*B, C, M)`.
Returns:
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
Raises:
ValueError: If length of shape of `input_x` is not equal to length of shape of `mat2` or
length of shape of `input_x` is less than 3.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
>>> mat2 = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
>>> output = ops.bmm(input_x, mat2)
>>> print(output)
[[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]
[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]]
"""
if not (isinstance(input_x, Tensor) and isinstance(mat2, Tensor)):
raise TypeError("For bmm op, inputs input_x and mat2 must be all tensors.")
bmm_op = _get_cache_prim(P.BatchMatMul)()
return bmm_op(input_x, mat2)
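# Because of the explicit isinstance check above, passing plain NumPy arrays (or any other
# non-Tensor) is expected to fail early with the message defined there, rather than inside
# the BatchMatMul primitive:
#     >>> F.bmm(np.ones((2, 4, 1, 3)), np.ones((2, 4, 3, 4)))
#     TypeError: For bmm op, inputs input_x and mat2 must be all tensors.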
def baddbmm(x, batch1, batch2, beta=1, alpha=1):
r"""
Performs a batch matrix-matrix product of matrices in batch1 and batch2. input is added to the final result.
@ -5959,6 +6006,7 @@ __all__ = [
'rot90',
'remainder',
'accumulate_n',
'iou'
'iou',
'bmm'
]
__all__.sort()

View File

@ -419,6 +419,11 @@ tensor_operator_registry.register('sigmoid', P.Sigmoid)
tensor_operator_registry.register('median', Median)
tensor_operator_registry.register('tanh', tanh)
tensor_operator_registry.register('exp', P.Exp)
tensor_operator_registry.register('addmv', addmv)
tensor_operator_registry.register('asinh', asinh)
tensor_operator_registry.register('atan', atan)
tensor_operator_registry.register('atanh', atanh)
tensor_operator_registry.register('bmm', bmm)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
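These registrations are what the new Tensor methods resolve at run time: each method body calls tensor_operator_registry.get('<name>')(self, ...) and receives the functional implementation registered here. A minimal sketch of that pattern, with hypothetical stand-in names:

class _Registry:
    """Toy stand-in for tensor_operator_registry."""
    def __init__(self):
        self._fns = {}

    def register(self, name, fn):
        self._fns[name] = fn

    def get(self, name):
        return self._fns[name]

_reg = _Registry()
_reg.register('bmm', lambda a, b: a @ b)  # stand-in for the functional bmm
# Tensor.bmm(self, mat2) then reduces to _reg.get('bmm')(self, mat2).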

View File

@ -0,0 +1,51 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
# all cases tested against dchip
def test_addmv_forward_tensor_api(nptype):
"""
Feature: test addmv forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.array([2., 3.]).astype(nptype))
mat = Tensor(np.array([[2., 5., 3.], [4., 2., 2.]]).astype(nptype))
vec = Tensor(np.array([3., 2., 4.]).astype(nptype))
output = x.addmv(mat, vec)
expected = np.array([30., 27.]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_addmv_forward_float32_tensor_api():
"""
Feature: test addmv forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_addmv_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
test_addmv_forward_tensor_api(np.float32)

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
# all cases tested against dchip
def test_asinh_forward_tensor_api(nptype):
"""
Feature: test asinh forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]).astype(nptype))
output = x.asinh()
expected = np.array([-2.3124382, 1.1947632, 1.8184465, 5.298342]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_asinh_forward_float32_tensor_api():
"""
Feature: test asinh forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_asinh_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
test_asinh_forward_tensor_api(np.float32)

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
# all cases tested against dchip
def test_atan_forward_tensor_api(nptype):
"""
Feature: test atan forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.array([1.0, 0.0]).astype(nptype))
output = x.atan()
expected = np.array([0.7853982, 0.0]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_atan_forward_float32_tensor_api():
"""
Feature: test atan forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_atan_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
test_atan_forward_tensor_api(np.float32)

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
# all cases tested against dchip
def test_atanh_forward_tensor_api(nptype):
"""
Feature: test atanh forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.array([0, -0.5]).astype(nptype))
output = x.atanh()
expected = np.array([0.0, -0.54930615]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_atanh_forward_float32_tensor_api():
"""
Feature: test atanh forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_atanh_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
test_atanh_forward_tensor_api(np.float32)

View File

@ -0,0 +1,80 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import functional as F
# all cases tested against dchip
def test_bmm_forward_tensor_api(nptype):
"""
Feature: test bmm forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.ones(shape=[2, 4, 1, 3]).astype(nptype))
y = Tensor(np.ones(shape=[2, 4, 3, 4]).astype(nptype))
output = x.bmm(y)
expected = 3 * np.ones(shape=[2, 4, 1, 4]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bmm_forward_float32_tensor_api():
"""
Feature: test bmm forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_bmm_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
test_bmm_forward_tensor_api(np.float32)
def test_bmm_forward_functional_api(nptype):
"""
Feature: test bmm forward functional api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.ones(shape=[2, 4, 1, 3]).astype(nptype))
y = Tensor(np.ones(shape=[2, 4, 3, 4]).astype(nptype))
output = F.bmm(x, y)
expected = 3 * np.ones(shape=[2, 4, 1, 4]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bmm_forward_float32_functional_api():
"""
Feature: test bmm forward functional api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_bmm_forward_functional_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
test_bmm_forward_functional_api(np.float32)

View File

@ -0,0 +1,52 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
def test_addmv_forward_tensor_api(nptype):
"""
Feature: test addmv forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.array([2., 3.]).astype(nptype))
mat = Tensor(np.array([[2., 5., 3.], [4., 2., 2.]]).astype(nptype))
vec = Tensor(np.array([3., 2., 4.]).astype(nptype))
output = x.addmv(mat, vec)
expected = np.array([30., 27.]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_addmv_forward_float32_tensor_api():
"""
Feature: test addmv forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
test_addmv_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
test_addmv_forward_tensor_api(np.float32)
if __name__ == '__main__':
test_addmv_forward_float32_tensor_api()

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -50,3 +50,34 @@ def test_asinh(dtype):
print(output)
expect = np.arcsinh(np_array)
assert np.allclose(output.asnumpy(), expect)
def test_asinh_forward_tensor_api(nptype):
"""
Feature: test asinh forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]).astype(nptype))
output = x.asinh()
expected = np.array([-2.3124382, 1.1947632, 1.8184465, 5.298342]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_asinh_forward_float32_tensor_api():
"""
Feature: test asinh forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
test_asinh_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
test_asinh_forward_tensor_api(np.float32)
if __name__ == '__main__':
test_asinh_forward_float32_tensor_api()

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -50,3 +50,34 @@ def test_atan(dtype):
print(output)
expect = np.arctan(np_array)
assert np.allclose(output.asnumpy(), expect)
def test_atan_forward_tensor_api(nptype):
"""
Feature: test atan forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.array([1.0, 0.0]).astype(nptype))
output = x.atan()
expected = np.array([0.7853982, 0.0]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_atan_forward_float32_tensor_api():
"""
Feature: test atan forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
test_atan_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
test_atan_forward_tensor_api(np.float32)
if __name__ == '__main__':
test_atan_forward_float32_tensor_api()

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -50,3 +50,34 @@ def test_atanh(dtype):
print(output)
expect = np.arctanh(np_array)
assert np.allclose(output.asnumpy(), expect)
def test_atanh_forward_tensor_api(nptype):
"""
Feature: test atanh forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.array([0, -0.5]).astype(nptype))
output = x.atanh()
expected = np.array([0.0, -0.54930615]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_atanh_forward_float32_tensor_api():
"""
Feature: test atanh forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
test_atanh_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
test_atanh_forward_tensor_api(np.float32)
if __name__ == '__main__':
test_atanh_forward_float32_tensor_api()

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -21,6 +21,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
class BatchMatMulNet(nn.Cell):
@ -166,3 +167,64 @@ def test_4d_transpose_ab():
[[5860., 6148., 6436., 6724.],
[6043., 6340., 6637., 6934.]]]], np.float16)
judge_result_correct(output.asnumpy(), expect)
def test_bmm_forward_tensor_api(nptype):
"""
Feature: test bmm forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.ones(shape=[2, 4, 1, 3]).astype(nptype))
y = Tensor(np.ones(shape=[2, 4, 3, 4]).astype(nptype))
output = x.bmm(y)
expected = 3 * np.ones(shape=[2, 4, 1, 4]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_bmm_forward_float32_tensor_api():
"""
Feature: test bmm forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
test_bmm_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
test_bmm_forward_tensor_api(np.float32)
def test_bmm_forward_functional_api(nptype):
"""
Feature: test bmm forward functional api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.ones(shape=[2, 4, 1, 3]).astype(nptype))
y = Tensor(np.ones(shape=[2, 4, 3, 4]).astype(nptype))
output = F.bmm(x, y)
expected = 3 * np.ones(shape=[2, 4, 1, 4]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_bmm_forward_float32_functional_api():
"""
Feature: test bmm forward functional api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
test_bmm_forward_functional_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
test_bmm_forward_functional_api(np.float32)
if __name__ == '__main__':
test_bmm_forward_float32_tensor_api()
test_bmm_forward_float32_functional_api()

View File

@ -0,0 +1,48 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
def test_addmv_forward_tensor_api(nptype):
"""
Feature: test addmv forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.array([2., 3.]).astype(nptype))
mat = Tensor(np.array([[2., 5., 3.], [4., 2., 2.]]).astype(nptype))
vec = Tensor(np.array([3., 2., 4.]).astype(nptype))
output = x.addmv(mat, vec)
expected = np.array([30., 27.]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_addmv_forward_float32_tensor_api():
"""
Feature: test addmv forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_addmv_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
test_addmv_forward_tensor_api(np.float32)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -39,3 +39,30 @@ def test_asinh_fp16():
output_ms = P.Asinh()(Tensor(x_np))
output_np = np.arcsinh(x_np.astype(np.float32)).astype(np.float16)
assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)
def test_asinh_forward_tensor_api(nptype):
"""
Feature: test asinh forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]).astype(nptype))
output = x.asinh()
expected = np.array([-2.3124382, 1.1947632, 1.8184465, 5.298342]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_asinh_forward_float32_tensor_api():
"""
Feature: test asinh forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_asinh_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
test_asinh_forward_tensor_api(np.float32)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -39,3 +39,30 @@ def test_atan_fp16():
output_ms = P.Atan()(Tensor(x_np))
output_np = np.arctan(x_np.astype(np.float32)).astype(np.float16)
assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)
def test_atan_forward_tensor_api(nptype):
"""
Feature: test atan forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.array([1.0, 0.0]).astype(nptype))
output = x.atan()
expected = np.array([0.7853982, 0.0]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_atan_forward_float32_tensor_api():
"""
Feature: test atan forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_atan_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
test_atan_forward_tensor_api(np.float32)

View File

@ -109,3 +109,30 @@ def test_atanh_complex128():
output_ms = P.Atanh()(Tensor(x_np))
expect = atanh(x_np)
assert np.allclose(output_ms.asnumpy(), expect)
def test_atanh_forward_tensor_api(nptype):
"""
Feature: test atanh forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.array([0, -0.5]).astype(nptype))
output = x.atanh()
expected = np.array([0.0, -0.54930615]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_atanh_forward_float32_tensor_api():
"""
Feature: test atanh forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_atanh_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
test_atanh_forward_tensor_api(np.float32)

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -22,6 +22,7 @@ from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import functional as F
class BatchMatMulNet(nn.Cell):
@ -32,6 +33,7 @@ class BatchMatMulNet(nn.Cell):
def construct(self, x, y):
return self.batch_matmul(x, y)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -153,7 +155,12 @@ def test_4d_transpose_ab():
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_4D_fp16():
def test_4d_fp_16():
"""
Feature: test BatchMatMul op.
Description: test BatchMatMul 4d input dtype float16.
Expectation: the result match with expected result.
"""
input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float16)
input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float16)
@ -172,9 +179,9 @@ def test_4D_fp16():
assert (output.asnumpy() == expect).all()
class BatchMatMul_d(nn.Cell):
class BatchMatMulDynamic(nn.Cell):
def __init__(self, transpose_a=False, transpose_b=False):
super(BatchMatMul_d, self).__init__()
super(BatchMatMulDynamic, self).__init__()
self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)
self.test_dynamic = inner.GpuConvertToDynamicShape()
@ -190,7 +197,7 @@ class BatchMatMul_d(nn.Cell):
def test_batchmatmul_dynamic():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = BatchMatMul_d()
net = BatchMatMulDynamic()
x1 = np.arange(8).reshape(2, 2, 2).astype(np.float32)
y1 = np.arange(28).reshape(2, 2, 7).astype(np.float32)
@ -205,3 +212,59 @@ def test_batchmatmul_dynamic():
output2 = net(Tensor(x2), Tensor(y2))
expect2 = np.matmul(x2, y2)
assert (output2.asnumpy() == expect2).all()
def test_bmm_forward_tensor_api(nptype):
"""
Feature: test bmm forward tensor api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.ones(shape=[2, 4, 1, 3]).astype(nptype))
y = Tensor(np.ones(shape=[2, 4, 3, 4]).astype(nptype))
output = x.bmm(y)
expected = 3 * np.ones(shape=[2, 4, 1, 4]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bmm_forward_float32_tensor_api():
"""
Feature: test bmm forward tensor api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_bmm_forward_tensor_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
test_bmm_forward_tensor_api(np.float32)
def test_bmm_forward_functional_api(nptype):
"""
Feature: test bmm forward functional api for given input dtype.
Description: test inputs for given input dtype.
Expectation: the result match with expected result.
"""
x = Tensor(np.ones(shape=[2, 4, 1, 3]).astype(nptype))
y = Tensor(np.ones(shape=[2, 4, 3, 4]).astype(nptype))
output = F.bmm(x, y)
expected = 3 * np.ones(shape=[2, 4, 1, 4]).astype(nptype)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bmm_forward_float32_functional_api():
"""
Feature: test bmm forward functional api.
Description: test float32 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_bmm_forward_functional_api(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
test_bmm_forward_functional_api(np.float32)