forked from mindspore-Ecosystem/mindspore
!41473 [ST][MS][OPS] bmm Functional API & ST - addmv, asinh, atan, atanh & bmm Tesnsor APIs and STs.
Merge pull request !41473 from alashkari/tensor-apis-sept04
This commit is contained in:
commit
1f3b82199f
|
@ -125,7 +125,14 @@ functional算子是经过初始化后的Primitive,可以直接作为函数使
|
|||
mindspore.ops.cdist
|
||||
|
||||
数学运算函数
|
||||
----------------
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. mscnplatformautosummary::
|
||||
:toctree: ops
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
mindspore.ops.bmm
|
||||
|
||||
逐元素运算
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
@ -40,7 +40,13 @@ mindspore.Tensor
|
|||
mindspore.Tensor.soft_shrink
|
||||
|
||||
数学运算方法
|
||||
----------------
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. mscnplatformautosummary::
|
||||
:toctree: Tensor
|
||||
:nosignatures:
|
||||
|
||||
mindspore.Tensor.bmm
|
||||
|
||||
逐元素运算
|
||||
^^^^^^^^^^^^^
|
||||
|
@ -55,6 +61,10 @@ mindspore.Tensor
|
|||
mindspore.Tensor.addcdiv
|
||||
mindspore.Tensor.addcmul
|
||||
mindspore.Tensor.asin
|
||||
mindspore.Tensor.addmv
|
||||
mindspore.Tensor.asinh
|
||||
mindspore.Tensor.atan
|
||||
mindspore.Tensor.atanh
|
||||
mindspore.Tensor.atan2
|
||||
mindspore.Tensor.bernoulli
|
||||
mindspore.Tensor.bitwise_and
|
||||
|
|
|
@ -38,7 +38,13 @@ Activation Function
|
|||
mindspore.Tensor.soft_shrink
|
||||
|
||||
Mathematical Methods
|
||||
--------------------
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. msplatformautosummary::
|
||||
:toctree: Tensor
|
||||
:nosignatures:
|
||||
|
||||
mindspore.Tensor.bmm
|
||||
|
||||
Element-wise Methods
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
@ -53,6 +59,10 @@ Element-wise Methods
|
|||
mindspore.Tensor.addcmul
|
||||
mindspore.Tensor.addr
|
||||
mindspore.Tensor.asin
|
||||
mindspore.Tensor.addmv
|
||||
mindspore.Tensor.asinh
|
||||
mindspore.Tensor.atan
|
||||
mindspore.Tensor.atanh
|
||||
mindspore.Tensor.atan2
|
||||
mindspore.Tensor.bernoulli
|
||||
mindspore.Tensor.bitwise_and
|
||||
|
|
|
@ -127,7 +127,14 @@ Distance Functions
|
|||
mindspore.ops.cdist
|
||||
|
||||
Mathematical Functions
|
||||
----------------------
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. msplatformautosummary::
|
||||
:toctree: ops
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
mindspore.ops.bmm
|
||||
|
||||
Element-by-Element Operations
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
|
@ -297,6 +297,11 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"addr", std::string("addr")}, // addr()
|
||||
{"add", std::string("add")}, // P.Add()
|
||||
{"asin", std::string("asin")}, // asin()
|
||||
{"addmv", std::string("addmv")}, // addmv()
|
||||
{"asinh", std::string("asinh")}, // asinh()
|
||||
{"atan", std::string("atan")}, // atan()
|
||||
{"atanh", std::string("atanh")}, // atanh()
|
||||
{"bmm", std::string("bmm")}, // bmm()
|
||||
}},
|
||||
{kObjectTypeRowTensorType,
|
||||
{
|
||||
|
|
|
@ -3076,3 +3076,38 @@ def addr(x, vec1, vec2, beta=1, alpha=1):
|
|||
Computes the outer-product of `vec1` and `vec2` and adds it to the vec1rix `x`.
|
||||
"""
|
||||
return F.addr(x, vec1, vec2, beta=1, alpha=1)
|
||||
|
||||
|
||||
def addmv(x, mat, vec, beta=1, alpha=1):
|
||||
r"""
|
||||
Multiplies matrix `mat` and vector `vec`. The vector `x` is added to the final result.
|
||||
"""
|
||||
return F.addmv(x, mat, vec, beta, alpha)
|
||||
|
||||
|
||||
def asinh(x):
|
||||
r"""
|
||||
Computes inverse hyperbolic sine of the input element-wise.
|
||||
"""
|
||||
return F.asinh(x)
|
||||
|
||||
|
||||
def atan(x):
|
||||
r"""
|
||||
Computes inverse tangent of the input element-wise.
|
||||
"""
|
||||
return F.atan(x)
|
||||
|
||||
|
||||
def atanh(x):
|
||||
r"""
|
||||
Computes inverse hyperbolic tangent of the input element-wise.
|
||||
"""
|
||||
return F.atanh(x)
|
||||
|
||||
|
||||
def bmm(input_x, mat2):
|
||||
r"""
|
||||
Computes matrix multiplication between two tensors by batch.
|
||||
"""
|
||||
return F.bmm(input_x, mat2)
|
||||
|
|
|
@ -695,25 +695,25 @@ class Tensor(Tensor_):
|
|||
r"""
|
||||
Executes the outer-product of `vec1` and `vec2` and adds it to the input tensor.
|
||||
|
||||
If `vec1` is a vector of size :vec1:`N` and `vec2` is a vector of size :vec1:`M`, then `x` must be
|
||||
broadcastable with a vec1rix of size :vec1:`(N, M)` and `out` will be a vec1rix of size :vec1:`(N, M)`.
|
||||
If `vec1` is a vector of size `N` and `vec2` is a vector of size `M`, then `x` must be
|
||||
broadcastable with a matrix of size `(N, M)` and `out` will be a matrix of size `(N, M)`.
|
||||
|
||||
The optional values `beta` and `alpha` are the scale factors on the outer product between `vec1` and `vec2`
|
||||
and the added vec1rix `x` respectively. If `beta` is 0, then `x` will be ignored.
|
||||
and the added matrix `x` respectively. If `beta` is 0, then `x` will be ignored.
|
||||
|
||||
.. math::
|
||||
output = β x + α (vec1 ⊗ vec2)
|
||||
|
||||
Args:
|
||||
vec1 (Tensor): The first tensor to be multiplied. The shape of the tensor is :vec1:`(N,)`.
|
||||
vec2 (Tensor): The second tensor to be multiplied. The shape of the tensor is :vec1:`(M,)`.
|
||||
vec1 (Tensor): The first tensor to be multiplied. The shape of the tensor is `(N,)`.
|
||||
vec2 (Tensor): The second tensor to be multiplied. The shape of the tensor is `(M,)`.
|
||||
beta (scalar[int, float, bool], optional): Multiplier for `x` (β). The `beta` must be int or
|
||||
float or bool, Default: 1.
|
||||
alpha (scalar[int, float, bool], optional): Multiplier for `vec1` @ `vec2` (α). The `alpha` must
|
||||
be int or float or bool, Default: 1.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of the output tensor is :vec1:`(N, M)`, has the same dtype as `x`.
|
||||
Tensor, the shape of the output tensor is `(N, M)`, has the same dtype as `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x`, `vec1`, `vec2` is not a Tensor.
|
||||
|
@ -5346,6 +5346,169 @@ class Tensor(Tensor_):
|
|||
validator.check_axis_in_range(axis, self.ndim)
|
||||
return tensor_operator_registry.get('median')(global_median, axis, keep_dims)(self)
|
||||
|
||||
def addmv(self, mat, vec, beta=1, alpha=1):
|
||||
r"""
|
||||
Multiplies matrix `mat` and vector `vec`. Input vector is added to the final result.
|
||||
|
||||
If mat is a :math:`(N, M)` tensor, vec is a 1-D tensor of size :math:`M`, then `x` must be broadcastable
|
||||
with a 1-D tensor of size :math:`N` and `out` will be 1-D tensor of size :math:`N`.
|
||||
|
||||
The optional values `beta` and `alpha` are the matrix-vector product between `mat` and `vec` and the scale
|
||||
factor for the added tensor `x` respectively. If `beta` is 0, then `x` will be ignored.
|
||||
|
||||
.. math::
|
||||
output = β x + α (mat @ vec)
|
||||
|
||||
Args:
|
||||
mat (Tensor): The first tensor to be multiplied. The shape of the tensor is :math:`(N, M)`.
|
||||
vec (Tensor): The second tensor to be multiplied. The shape of the tensor is :math:`(M,)`.
|
||||
beta (scalar[int, float, bool], optional): Multiplier for `x` (β). The `beta` must be int or
|
||||
float or bool, Default: 1.
|
||||
alpha (scalar[int, float, bool], optional): Multiplier for `mat` @ `vec` (α). The `alpha` must
|
||||
be int or float or bool, Default: 1.
|
||||
|
||||
Returns:
|
||||
Tensor, the shape of the output tensor is :math:`(N,)`, has the same dtype as `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `mat`, `vec`, `x` is not a Tensor.
|
||||
TypeError: If input tensor and `x`, `mat`, 'vec' are not the same dtype.
|
||||
ValueError: If `mat` is not a 2-D Tensor.
|
||||
If `x`, `vec` is not a 1-D Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([2., 3.]).astype(np.float32))
|
||||
>>> mat = Tensor(np.array([[2., 5., 3.], [4., 2., 2.]]).astype(np.float32))
|
||||
>>> vec = Tensor(np.array([3., 2., 4.]).astype(np.float32))
|
||||
>>> output = x.addmv(mat, vec)
|
||||
>>> print(output)
|
||||
[30. 27.]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('addmv')(self, mat, vec, beta=1, alpha=1)
|
||||
|
||||
def asinh(self):
|
||||
r"""
|
||||
Computes inverse hyperbolic sine of the input element-wise.
|
||||
|
||||
.. math::
|
||||
out_i = \sinh^{-1}(input_i)
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape and type as input.
|
||||
|
||||
Raises:
|
||||
TypeError: If input is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]), mindspore.float32)
|
||||
>>> output = x.asinh()
|
||||
>>> print(output)
|
||||
[-2.3124382 1.1947632 1.8184465 5.298342 ]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('asinh')(self)
|
||||
|
||||
def atan(self):
|
||||
r"""
|
||||
Computes the trigonometric inverse tangent of the input element-wise.
|
||||
|
||||
.. math::
|
||||
out_i = tan^{-1}(x_i)
|
||||
|
||||
Returns:
|
||||
A Tensor, has the same type as the input.
|
||||
|
||||
Raises:
|
||||
TypeError: If input is not a Tensor.
|
||||
TypeError: If input tensor dtype is not float16 or float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([1.0, 0.0]), mindspore.float32)
|
||||
>>> output = x.atan()
|
||||
>>> print(output)
|
||||
[0.7853982 0.0]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('atan')(self)
|
||||
|
||||
def atanh(self):
|
||||
r"""
|
||||
Computes inverse hyperbolic tangent of the input element-wise.
|
||||
|
||||
.. math::
|
||||
out_i = \tanh^{-1}(x_{i})
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
Returns:
|
||||
A Tensor, has the same type as the input.
|
||||
|
||||
Raises:
|
||||
TypeError: If input is not a Tensor.
|
||||
TypeError: If input tensor dtype is not float16 or float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([0, -0.5]), mindspore.float32)
|
||||
>>> output = ops.atanh(x)
|
||||
>>> print(output)
|
||||
[ 0. -0.54930615]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('atanh')(self)
|
||||
|
||||
def bmm(self, mat2):
|
||||
r"""
|
||||
Computes matrix multiplication between two tensors by batch.
|
||||
|
||||
.. math::
|
||||
\text{output}[..., :, :] = \text{matrix}(input_x[..., :, :]) * \text{matrix}(mat2[..., :, :])
|
||||
|
||||
The first input tensor must be not less than `3` and the second input must be not less than `2`.
|
||||
|
||||
Args:
|
||||
mat2 (Tensor) - The tensor to be multiplied. The shape of the tensor is :math:`(*B, C, M)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
|
||||
|
||||
Raises:
|
||||
ValueError: If length of shape of `input_x` is not equal to length of shape of `mat2` or
|
||||
length of shape of `input_x` is less than 3.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
|
||||
>>> mat2 = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
|
||||
>>> output = input_x.bmm(mat2)
|
||||
>>> print(output)
|
||||
[[[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]]
|
||||
[[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]]]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('bmm')(self, mat2)
|
||||
|
||||
|
||||
class RowTensor(RowTensor_):
|
||||
"""
|
||||
|
|
|
@ -266,7 +266,8 @@ from .math_func import (
|
|||
kron,
|
||||
rot90,
|
||||
remainder,
|
||||
iou
|
||||
iou,
|
||||
bmm
|
||||
)
|
||||
from .nn_func import (
|
||||
adaptive_avg_pool1d,
|
||||
|
|
|
@ -5179,6 +5179,53 @@ def matmul(x1, x2):
|
|||
return reshape_op(res, shape_out)
|
||||
|
||||
|
||||
def bmm(input_x, mat2):
|
||||
"""
|
||||
Computes matrix multiplication between two tensors by batch.
|
||||
|
||||
.. math::
|
||||
|
||||
\text{output}[..., :, :] = \text{matrix}(input_x[..., :, :]) * \text{matrix}(mat2[..., :, :])
|
||||
|
||||
The first input tensor must be not less than `3` and the second input must be not less than `2`.
|
||||
|
||||
Args:
|
||||
input_x (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(*B, N, C)`,
|
||||
where :math:`*B` represents the batch size which can be multidimensional, :math:`N` and :math:`C` are the
|
||||
size of the last two dimensions.
|
||||
mat2 (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(*B, C, M)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
|
||||
|
||||
Raises:
|
||||
ValueError: If length of shape of `input_x` is not equal to length of shape of `y` or
|
||||
length of shape of `input_x` is less than 3.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
|
||||
>>> mat2 = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
|
||||
>>> output = ops.bmm(input_x, mat2)
|
||||
>>> print(output)
|
||||
[[[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]]
|
||||
[[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]]]
|
||||
"""
|
||||
if not (isinstance(input_x, Tensor) and isinstance(mat2, Tensor)):
|
||||
raise TypeError("For bmm op, inputs input_x and mat2 must be all tensors.")
|
||||
|
||||
bmm_op = _get_cache_prim(P.BatchMatMul)()
|
||||
return bmm_op(input_x, mat2)
|
||||
|
||||
|
||||
def baddbmm(x, batch1, batch2, beta=1, alpha=1):
|
||||
r"""
|
||||
Performs a batch matrix-matrix product of matrices in batch1 and batch2. input is added to the final result.
|
||||
|
@ -5959,6 +6006,7 @@ __all__ = [
|
|||
'rot90',
|
||||
'remainder',
|
||||
'accumulate_n',
|
||||
'iou'
|
||||
'iou',
|
||||
'bmm'
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -419,6 +419,11 @@ tensor_operator_registry.register('sigmoid', P.Sigmoid)
|
|||
tensor_operator_registry.register('median', Median)
|
||||
tensor_operator_registry.register('tanh', tanh)
|
||||
tensor_operator_registry.register('exp', P.Exp)
|
||||
tensor_operator_registry.register('addmv', addmv)
|
||||
tensor_operator_registry.register('asinh', asinh)
|
||||
tensor_operator_registry.register('atan', atan)
|
||||
tensor_operator_registry.register('atanh', atanh)
|
||||
tensor_operator_registry.register('bmm', bmm)
|
||||
# ms cannot support Tensor(True) compare
|
||||
tensor_operator_registry.register('__eq__', equal)
|
||||
tensor_operator_registry.register('__ne__', not_equal)
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
|
||||
# all cases tested against dchip
|
||||
|
||||
|
||||
def test_addmv_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test addmv forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.array([2., 3.]).astype(nptype))
|
||||
mat = Tensor(np.array([[2., 5., 3.], [4., 2., 2.]]).astype(nptype))
|
||||
vec = Tensor(np.array([3., 2., 4.]).astype(nptype))
|
||||
output = x.addmv(mat, vec)
|
||||
expected = np.array([30., 27.]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_addmv_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test addmv forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
test_addmv_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
test_addmv_forward_tensor_api(np.float32)
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
|
||||
# all cases tested against dchip
|
||||
|
||||
|
||||
def test_asinh_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test asinh forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]).astype(nptype))
|
||||
output = x.asinh()
|
||||
expected = np.array([-2.3124382, 1.1947632, 1.8184465, 5.298342]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_asinh_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test asinh forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
test_asinh_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
test_asinh_forward_tensor_api(np.float32)
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
|
||||
# all cases tested against dchip
|
||||
|
||||
|
||||
def test_atan_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test atan forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.array([1.0, 0.0]).astype(nptype))
|
||||
output = x.atan()
|
||||
expected = np.array([0.7853982, 0.0]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_atan_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test atan forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
test_atan_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
test_atan_forward_tensor_api(np.float32)
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
|
||||
# all cases tested against dchip
|
||||
|
||||
|
||||
def test_atanh_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test atanh forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.array([0, -0.5]).astype(nptype))
|
||||
output = x.atanh()
|
||||
expected = np.array([0.0, -0.54930615]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_atanh_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test atanh forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
test_atanh_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
test_atanh_forward_tensor_api(np.float32)
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
# all cases tested against dchip
|
||||
|
||||
|
||||
def test_bmm_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test bmm forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.ones(shape=[2, 4, 1, 3]).astype(nptype))
|
||||
y = Tensor(np.ones(shape=[2, 4, 3, 4]).astype(nptype))
|
||||
output = x.bmm(y)
|
||||
expected = 3 * np.ones(shape=[2, 4, 1, 4]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_bmm_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test bmm forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
test_bmm_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
test_bmm_forward_tensor_api(np.float32)
|
||||
|
||||
|
||||
def test_bmm_forward_functional_api(nptype):
|
||||
"""
|
||||
Feature: test bmm forward functional api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.ones(shape=[2, 4, 1, 3]).astype(nptype))
|
||||
y = Tensor(np.ones(shape=[2, 4, 3, 4]).astype(nptype))
|
||||
output = F.bmm(x, y)
|
||||
expected = 3 * np.ones(shape=[2, 4, 1, 4]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_bmm_forward_float32_functional_api():
|
||||
"""
|
||||
Feature: test bmm forward functional api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
test_bmm_forward_functional_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
test_bmm_forward_functional_api(np.float32)
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
def test_addmv_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test addmv forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.array([2., 3.]).astype(nptype))
|
||||
mat = Tensor(np.array([[2., 5., 3.], [4., 2., 2.]]).astype(nptype))
|
||||
vec = Tensor(np.array([3., 2., 4.]).astype(nptype))
|
||||
output = x.addmv(mat, vec)
|
||||
expected = np.array([30., 27.]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_addmv_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test addmv forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_addmv_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_addmv_forward_tensor_api(np.float32)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_addmv_forward_float32_tensor_api()
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -50,3 +50,34 @@ def test_asinh(dtype):
|
|||
print(output)
|
||||
expect = np.arcsinh(np_array)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
def test_asinh_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test asinh forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]).astype(nptype))
|
||||
output = x.asinh()
|
||||
expected = np.array([-2.3124382, 1.1947632, 1.8184465, 5.298342]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_asinh_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test asinh forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_asinh_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_asinh_forward_tensor_api(np.float32)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_asinh_forward_float32_tensor_api()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -50,3 +50,34 @@ def test_atan(dtype):
|
|||
print(output)
|
||||
expect = np.arctan(np_array)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
def test_atan_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test atan forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.array([1.0, 0.0]).astype(nptype))
|
||||
output = x.atan()
|
||||
expected = np.array([0.7853982, 0.0]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_atan_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test atan forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_atan_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_atan_forward_tensor_api(np.float32)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_atan_forward_float32_tensor_api()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -50,3 +50,34 @@ def test_atanh(dtype):
|
|||
print(output)
|
||||
expect = np.arctanh(np_array)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
def test_atanh_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test atanh forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.array([0, -0.5]).astype(nptype))
|
||||
output = x.atanh()
|
||||
expected = np.array([0.0, -0.54930615]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_atanh_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test atanh forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_atanh_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_atanh_forward_tensor_api(np.float32)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_atanh_forward_float32_tensor_api()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -21,6 +21,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
class BatchMatMulNet(nn.Cell):
|
||||
|
@ -166,3 +167,64 @@ def test_4d_transpose_ab():
|
|||
[[5860., 6148., 6436., 6724.],
|
||||
[6043., 6340., 6637., 6934.]]]], np.float16)
|
||||
judge_result_correct(output.asnumpy(), expect)
|
||||
|
||||
|
||||
def test_bmm_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test bmm forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.ones(shape=[2, 4, 1, 3]).astype(nptype))
|
||||
y = Tensor(np.ones(shape=[2, 4, 3, 4]).astype(nptype))
|
||||
output = x.bmm(y)
|
||||
expected = 3 * np.ones(shape=[2, 4, 1, 4]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bmm_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test bmm forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_bmm_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_bmm_forward_tensor_api(np.float32)
|
||||
|
||||
|
||||
def test_bmm_forward_functional_api(nptype):
|
||||
"""
|
||||
Feature: test bmm forward functional api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.ones(shape=[2, 4, 1, 3]).astype(nptype))
|
||||
y = Tensor(np.ones(shape=[2, 4, 3, 4]).astype(nptype))
|
||||
output = F.bmm(x, y)
|
||||
expected = 3 * np.ones(shape=[2, 4, 1, 4]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bmm_forward_float32_functional_api():
|
||||
"""
|
||||
Feature: test bmm forward functional api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_bmm_forward_functional_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_bmm_forward_functional_api(np.float32)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_bmm_forward_float32_tensor_api()
|
||||
test_bmm_forward_float32_functional_api()
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
def test_addmv_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test addmv forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.array([2., 3.]).astype(nptype))
|
||||
mat = Tensor(np.array([[2., 5., 3.], [4., 2., 2.]]).astype(nptype))
|
||||
vec = Tensor(np.array([3., 2., 4.]).astype(nptype))
|
||||
output = x.addmv(mat, vec)
|
||||
expected = np.array([30., 27.]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_addmv_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test addmv forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_addmv_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_addmv_forward_tensor_api(np.float32)
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -39,3 +39,30 @@ def test_asinh_fp16():
|
|||
output_ms = P.Asinh()(Tensor(x_np))
|
||||
output_np = np.arcsinh(x_np.astype(np.float32)).astype(np.float16)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)
|
||||
|
||||
|
||||
def test_asinh_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test asinh forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]).astype(nptype))
|
||||
output = x.asinh()
|
||||
expected = np.array([-2.3124382, 1.1947632, 1.8184465, 5.298342]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_asinh_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test asinh forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_asinh_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_asinh_forward_tensor_api(np.float32)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -39,3 +39,30 @@ def test_atan_fp16():
|
|||
output_ms = P.Atan()(Tensor(x_np))
|
||||
output_np = np.arctan(x_np.astype(np.float32)).astype(np.float16)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)
|
||||
|
||||
|
||||
def test_atan_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test atan forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.array([1.0, 0.0]).astype(nptype))
|
||||
output = x.atan()
|
||||
expected = np.array([0.7853982, 0.0]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_atan_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test atan forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_atan_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_atan_forward_tensor_api(np.float32)
|
||||
|
|
|
@ -109,3 +109,30 @@ def test_atanh_complex128():
|
|||
output_ms = P.Atanh()(Tensor(x_np))
|
||||
expect = atanh(x_np)
|
||||
assert np.allclose(output_ms.asnumpy(), expect)
|
||||
|
||||
|
||||
def test_atanh_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test atanh forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.array([0, -0.5]).astype(nptype))
|
||||
output = x.atanh()
|
||||
expected = np.array([0.0, -0.54930615]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_atanh_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test atanh forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_atanh_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_atanh_forward_tensor_api(np.float32)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -22,6 +22,7 @@ from mindspore import Tensor
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
class BatchMatMulNet(nn.Cell):
|
||||
|
@ -32,6 +33,7 @@ class BatchMatMulNet(nn.Cell):
|
|||
def construct(self, x, y):
|
||||
return self.batch_matmul(x, y)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -153,7 +155,12 @@ def test_4d_transpose_ab():
|
|||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_4D_fp16():
|
||||
def test_4d_fp_16():
|
||||
"""
|
||||
Feature: test BatchMatMul op.
|
||||
Description: test BatchMatMul 4d input dtype float16.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float16)
|
||||
input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float16)
|
||||
|
||||
|
@ -172,9 +179,9 @@ def test_4D_fp16():
|
|||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
class BatchMatMul_d(nn.Cell):
|
||||
class BatchMatMulDynamic(nn.Cell):
|
||||
def __init__(self, transpose_a=False, transpose_b=False):
|
||||
super(BatchMatMul_d, self).__init__()
|
||||
super(BatchMatMulDynamic, self).__init__()
|
||||
self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)
|
||||
self.test_dynamic = inner.GpuConvertToDynamicShape()
|
||||
|
||||
|
@ -190,7 +197,7 @@ class BatchMatMul_d(nn.Cell):
|
|||
def test_batchmatmul_dynamic():
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
net = BatchMatMul_d()
|
||||
net = BatchMatMulDynamic()
|
||||
|
||||
x1 = np.arange(8).reshape(2, 2, 2).astype(np.float32)
|
||||
y1 = np.arange(28).reshape(2, 2, 7).astype(np.float32)
|
||||
|
@ -205,3 +212,59 @@ def test_batchmatmul_dynamic():
|
|||
output2 = net(Tensor(x2), Tensor(y2))
|
||||
expect2 = np.matmul(x2, y2)
|
||||
assert (output2.asnumpy() == expect2).all()
|
||||
|
||||
|
||||
def test_bmm_forward_tensor_api(nptype):
|
||||
"""
|
||||
Feature: test bmm forward tensor api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.ones(shape=[2, 4, 1, 3]).astype(nptype))
|
||||
y = Tensor(np.ones(shape=[2, 4, 3, 4]).astype(nptype))
|
||||
output = x.bmm(y)
|
||||
expected = 3 * np.ones(shape=[2, 4, 1, 4]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_bmm_forward_float32_tensor_api():
|
||||
"""
|
||||
Feature: test bmm forward tensor api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_bmm_forward_tensor_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_bmm_forward_tensor_api(np.float32)
|
||||
|
||||
|
||||
def test_bmm_forward_functional_api(nptype):
|
||||
"""
|
||||
Feature: test bmm forward functional api for given input dtype.
|
||||
Description: test inputs for given input dtype.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.ones(shape=[2, 4, 1, 3]).astype(nptype))
|
||||
y = Tensor(np.ones(shape=[2, 4, 3, 4]).astype(nptype))
|
||||
output = F.bmm(x, y)
|
||||
expected = 3 * np.ones(shape=[2, 4, 1, 4]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_bmm_forward_float32_functional_api():
|
||||
"""
|
||||
Feature: test bmm forward functional api.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_bmm_forward_functional_api(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_bmm_forward_functional_api(np.float32)
|
||||
|
|
Loading…
Reference in New Issue