support mH mm msort mT NanToNum

This commit is contained in:
fengyihang 2022-11-14 11:09:31 +08:00
parent 20d1c98691
commit 4130c29364
29 changed files with 661 additions and 1 deletions

View File

@ -304,6 +304,7 @@ Reduction函数
mindspore.ops.dot
mindspore.ops.matmul
mindspore.ops.matrix_solve
mindspore.ops.mm
mindspore.ops.ger
mindspore.ops.renorm
mindspore.ops.tensor_dot
@ -389,6 +390,8 @@ Array操作
mindspore.ops.matrix_diag_part
mindspore.ops.matrix_set_diag
mindspore.ops.meshgrid
mindspore.ops.msort
mindspore.ops.nan_to_num
mindspore.ops.normal
mindspore.ops.nonzero
mindspore.ops.numel

View File

@ -0,0 +1,9 @@
mindspore.Tensor.mH
====================
.. py:method:: mindspore.Tensor.mH
:property:
访问此属性等价于调用self.adjoint()方法。
详情请参考 :func:`mindspore.ops.adjoint`

View File

@ -0,0 +1,11 @@
mindspore.Tensor.mT
====================
.. py:method:: mindspore.Tensor.mT
:property:
返回将最后两个维度交换的Tensor。
访问x.mT属性等价于调用x.swapaxes(-2, -1)方法。
详情请参考 :func:`mindspore.Tensor.swapaxes`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.mm
====================
.. py:method:: mindspore.Tensor.mm(mat2)
详情请参考 :func:`mindspore.ops.mm`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.msort
=======================
.. py:method:: mindspore.Tensor.msort()
详情请参考 :func:`mindspore.ops.msort`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.nan_to_num
============================
.. py:method:: mindspore.Tensor.nan_to_num(nan=0.0, posinf=None, neginf=None)
详情请参考 :func:`mindspore.ops.nan_to_num`

View File

@ -159,9 +159,14 @@ mindspore.Tensor
mindspore.Tensor.max
mindspore.Tensor.mean
mindspore.Tensor.median
mindspore.Tensor.mH
mindspore.Tensor.min
mindspore.Tensor.minimum
mindspore.Tensor.mm
mindspore.Tensor.msort
mindspore.Tensor.mT
mindspore.Tensor.multiply
mindspore.Tensor.nan_to_num
mindspore.Tensor.narrow
mindspore.Tensor.nbytes
mindspore.Tensor.ndim

View File

@ -0,0 +1,22 @@
mindspore.ops.mm
=================
.. py:function:: mindspore.ops.mm(mat1, mat2)
计算两个矩阵的乘积。
如果 `mat1` 是一个 :math:`(n \times m)` 的Tensor`mat2` 是一个 :math:`(m \times p)` 的Tensor`out` 则会是一个 :math:`(n \times p)` 的Tensor。
.. note::
此函数不能支持广播。若需要可广播的方法,请参考 :func:`mindspore.ops.matmul`
参数:
- **mat1** (Tensor) - 矩阵相乘的第一个矩阵, `mat1` 的最后一维度必须和 `mat2` 的第一维度相等。
- **mat2** (Tensor) - 矩阵相乘的第二个矩阵, `mat1` 的最后一维度必须和 `mat2` 的第一维度相等。
返回:
Tensor或Scalar输入的矩阵乘积。
异常:
- **ValueError** - `mat1` 的最后一维度和 `mat2` 的倒数第二维度不相等。
- **ValueError** - `mat1` 或者 `mat2` 不是一个矩阵。

View File

@ -0,0 +1,17 @@
mindspore.ops.msort
====================
.. py:function:: mindspore.ops.msort(x)
将输入Tensor的元素沿其第一维按值升序排序。
ops.msort(t)相当于ops.Sort(axis=0)(t)[0]。另外可以参考 :class:`mindspore.ops.Sort()`
参数:
- **x** (Tensor) - 需要排序的输入类型必须是float16或者float32。
返回:
排序后的Tensor与输入的shape和dtype一致。
异常:
- **TypeError** - `x` 的类型既不是float16也不是float32。

View File

@ -0,0 +1,19 @@
mindspore.ops.nan_to_num
=========================
.. py:function:: mindspore.ops.nan_to_num(x, nan=0.0, posinf=None, neginf=None)
`x` 中的`NaN`、正无穷大和负无穷大值分别替换为 `nan`, `posinf`, 和 `neginf` 指定的值。默认情况下NaN替换为0正无穷替换为 `x` 类型支持的上限,负无穷替换为由 `x` 类型支持的下限。
参数:
- **x** (Tensor) - shape为 :math:`(x_1, x_2, ..., x_R)` 的tensor。类型必须为float32或float16。
- **nan** (float) - 替换 `NaN` 的值。默认值为0.0。
- **posinf** (float) - 如果是一个数字则为替换正无穷的值。如果为None则将正无穷替换为 `x` 类型支持的上限。默认值为None。
- **neginf** (float) - 如果是一个数字则为替换负无穷的值。如果为None则将负无穷替换为 `x` 类型支持的下限。默认值为None。
返回:
Tensor数据shape和类型与 `x` 相同。
异常:
- **TypeError** - `x` 不是一个Tensor。
- **TypeError** - `x` 的类型既不是float16也不是float32。

View File

@ -165,9 +165,14 @@
mindspore.Tensor.max
mindspore.Tensor.mean
mindspore.Tensor.median
mindspore.Tensor.mH
mindspore.Tensor.min
mindspore.Tensor.minimum
mindspore.Tensor.mm
mindspore.Tensor.msort
mindspore.Tensor.mT
mindspore.Tensor.multiply
mindspore.Tensor.nan_to_num
mindspore.Tensor.narrow
mindspore.Tensor.nbytes
mindspore.Tensor.ndim

View File

@ -304,6 +304,7 @@ Linear Algebraic Functions
mindspore.ops.dot
mindspore.ops.matmul
mindspore.ops.matrix_solve
mindspore.ops.mm
mindspore.ops.ger
mindspore.ops.renorm
mindspore.ops.tensor_dot
@ -389,6 +390,8 @@ Array Operation
mindspore.ops.matrix_diag_part
mindspore.ops.matrix_set_diag
mindspore.ops.meshgrid
mindspore.ops.msort
mindspore.ops.nan_to_num
mindspore.ops.normal
mindspore.ops.nonzero
mindspore.ops.numel

View File

@ -398,8 +398,11 @@ BuiltInTypeMap &GetMethodMap() {
{"mvlgamma", std::string("mvlgamma")}, // mvlgamma()
{"matmul", std::string("matmul")}, // matmul()
{"maximum", std::string("maximum")}, // maximum()
{"msort", std::string("msort")}, // msort()
{"mm", std::string("mm")}, // mm()
{"mul", std::string("mul")}, // mul()
{"multiply", std::string("multiply")}, // multiply()
{"nan_to_num", std::string("nan_to_num")}, // nan_to_num()
{"neg", std::string("neg")}, // neg()
{"ne", std::string("ne")}, // ne()
{"sinh", std::string("sinh")}, // sinh()
@ -461,6 +464,8 @@ BuiltInTypeMap &GetAttrMap() {
{"itemsize", std::string("itemsize_")}, // C.itemsize_
{"nbytes", std::string("nbytes_")}, // C.nbytes_
{"strides", std::string("strides_")}, // C.strides_
{"mH", std::string("adjoint")}, // C.adjoint
{"mT", std::string("mT")}, // C.mT_
}},
{kObjectTypeRowTensorType,
{

View File

@ -856,6 +856,35 @@ def median(x, global_median, axis=0, keep_dims=False):
return median_(x)
def msort(x):
"""
For details, please refer to :func:`mindspore.ops.msort`.
"""
return F.msort(x)
def mm(mat1, mat2):
"""
For details, please refer to :func:`mindspore.ops.mm`.
"""
return F.mm(mat1, mat2)
def mT(x):
"""
Returns a view of this tensor with the last two dimensions transposed.
x.mT is equivalent to x.transpose(-2, -1).
"""
return swapaxes(x, -2, -1)
def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
"""
For details, please refer to :func:`mindspore.ops.nan_to_num`.
"""
return F.nan_to_num(x, nan, posinf, neginf)
def cumsum(x, axis=None, dtype=None):
"""
Returns the cumulative sum of the elements along a given axis.

View File

@ -4121,6 +4121,23 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('lstsq')(self, A)
@property
def mH(self):
r"""
Accessing this property is equivalent to Calling self.adjoint().
For details, please refer to :func:`mindspore.ops.adjoint`.
"""
return self.adjoint()
@property
def mT(self):
r"""
Returns a view of this tensor with the last two dimensions transposed.
x.mT is equivalent to x.swapaxes(-2, -1).
For details, please refer to :func:`mindspore.Tensor.swapaxes`.
"""
return self.swapaxes(-2, -1)
def mvlgamma(self, p):
r"""
Computes the multivariate log-gamma function with dimension p element-wise.
@ -4243,6 +4260,20 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('maximum')(self, other)
def mm(self, mat2):
r"""
For details, please refer to :func:`mindspore.ops.mm`.
"""
self._init_check()
return tensor_operator_registry.get('mm')(self, mat2)
def msort(self):
r"""
For details, please refer to :func:`mindspore.ops.msort`.
"""
self._init_check()
return tensor_operator_registry.get('msort')(self)
def mul(self, value):
r"""
Multiplies two tensors element-wise.
@ -4282,6 +4313,12 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('mul')(self, value)
def nan_to_num(self, nan=0.0, posinf=None, neginf=None):
"""
For details, please refer to :func:`mindspore.ops.nan_to_num`.
"""
return tensor_operator_registry.get('nan_to_num')(self, nan, posinf, neginf)
def neg(self):
r"""
Returns a tensor with negative values of the input tensor element-wise.

View File

@ -28,7 +28,7 @@ from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.composite.random_ops import normal, laplace, uniform, gamma, poisson, multinomial
from mindspore.ops.composite.math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul, cummin
from mindspore.ops.composite.math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul, cummin, mm
from mindspore.ops.composite.array_ops import repeat_interleave, repeat_elements, sequence_mask
from mindspore.ops.composite.vmap_ops import _VmapGeneralPreprocess, _VmapGeneralRule
from mindspore.ops.function.clip_func import clip_by_value
@ -62,6 +62,7 @@ __all__ = [
'repeat_interleave',
'sequence_mask',
'matmul',
'mm',
'_Grad',
'_Vmap',
'_VmapGeneralPreprocess']

View File

@ -731,6 +731,48 @@ def matmul(x1, x2, dtype=None):
return res
def mm(mat1, mat2):
r"""
Returns the matrix product of two arrays.
If `mat1` is a :math:`(n \times m)` Tensor, `mat2` is a
:math:`(m \times p)` Tensor, `out` will be a :math:`(n \times p)` Tensor.
Note:
This function does not broadcast. For broadcasting matrix products, see :func:`mindspore.ops.matmul`.
Args:
mat1 (Tensor): The first matrix to be matrix multiplied.
The last dimension of `mat1` must be the same size as the first dimension of `mat2`.
mat2 (Tensor): The second matrix to be matrix multiplied.
The last dimension of `mat1` must be the same size as the first dimension of `mat2`.
Returns:
Tensor or scalar, the matrix product of the inputs.
Raises:
ValueError: If the last dimension of `mat1` is not the same size as the
second-to-last dimension of `mat2`.
ValueError: If `mat1` or `mat2` is not a matrix.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> import numpy as np
>>> x1 = ms.Tensor(np.random.rand(2, 3))
>>> x2 = ms.Tensor(np.random.rand(3, 4))
>>> out = ops.mm(x1, x2)
>>> print(out.shape)
(2, 4)
"""
if mat1.ndim != 2 or mat2.ndim != 2:
raise ValueError(f"For mm, the input tensor must be a matrix, "
f"but got mat1.ndim:{mat1.ndim}, mat2.ndim:{mat2.ndim}")
return matmul(mat1, mat2)
def cummin(x, axis):
r"""
Returns a tuple (values,indices) where 'values' is the cumulative minimum value of input Tensor `x`

View File

@ -198,6 +198,7 @@ from .math_func import (
matrix_solve,
maximum,
median,
nan_to_num,
logaddexp,
logaddexp2,
logit,
@ -389,6 +390,7 @@ from .nn_func import (
lp_pool1d,
lp_pool2d,
mse_loss,
msort
)
from .linalg_func import (
svd,

View File

@ -7346,6 +7346,7 @@ __all__ = [
'tensor_mul',
'mul',
'multiply',
'nan_to_num',
'tensor_div',
'div',
'divide',

View File

@ -4965,6 +4965,38 @@ def mse_loss(input_x, target, reduction='mean'):
return _get_cache_prim(P.Cast)()(x, input_dtype)
def msort(x):
r"""
Sorts the elements of the input tensor along its first dimension in ascending order by value.
ops.msort(t) is equivalent to ops.Sort(axis=0)(t)[0]. See also :class:`mindspore.ops.Sort()`.
Args:
x (Tensor): The input to sort, with float16 or float32 data type.
Returns:
A tensor whose values are the sorted values, with the same shape and data type as input.
Raises:
TypeError: If dtype of `x` is neither float16 nor float32.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> import numpy as np
>>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16)
>>> output = ops.msort(x)
>>> print(output)
[[4. 2. 1.]
[5. 6. 3.]
[8. 9. 7.]]
"""
return ops.Sort(axis=0)(x)[0]
__all__ = [
'adaptive_avg_pool1d',
'adaptive_avg_pool2d',
@ -5037,5 +5069,6 @@ __all__ = [
'max_unpool2d',
'max_unpool3d',
'mse_loss',
'msort',
]
__all__.sort()

View File

@ -30,6 +30,7 @@ from mindspore.ops.operations.array_ops import UniqueConsecutive, Triu
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
from mindspore.ops.operations._inner_ops import Roll
from mindspore.ops.composite.array_ops import repeat_interleave
from mindspore.ops.composite.math_ops import mm
typeof = Primitive('typeof')
hastype = Primitive('hastype')
@ -288,6 +289,9 @@ tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)
tensor_operator_registry.register('csr_to_dense', csr_to_dense)
tensor_operator_registry.register('narrow', narrow)
tensor_operator_registry.register('sort', sort)
tensor_operator_registry.register('msort', msort)
tensor_operator_registry.register('mm', mm)
tensor_operator_registry.register('nan_to_num', nan_to_num)
tensor_operator_registry.register('csr_to_coo', csr_to_coo)
tensor_operator_registry.register('zeros', zeros)
tensor_operator_registry.register('unsorted_segment_min', unsorted_segment_min)

View File

@ -0,0 +1,51 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x1, x2):
output = ops.mm(x1, x2)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_mm_normal(mode):
"""
Feature: mm
Description: Verify the result of mm
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x1 = ms.Tensor(np.arange(6).reshape((2, 3)), dtype=ms.float32)
x2 = ms.Tensor(np.arange(12).reshape((3, 4)), dtype=ms.float32)
out = net(x1, x2)
expect_out = np.array([[20, 23, 26, 29],
[56, 68, 80, 92]])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,51 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x):
output = ops.msort(x)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_msort_normal(mode):
"""
Feature: msort
Description: Verify the result of msort
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16)
out = net(x)
expect_out = np.array([[4., 2., 1.],
[5., 6., 3.],
[8., 9., 7.]])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,46 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x, nan, posinf, neginf):
output = ops.nan_to_num(x, nan, posinf, neginf)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_nan_to_num_normal(mode):
"""
Feature: nan_to_num
Description: Verify the result of nan_to_num
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.array([float('nan'), float('inf'), -float('inf'), 3.14]), ms.float32)
out = net(x, 1.0, 2.0, 3.0)
expect_out = np.array([1., 2., 3., 3.14])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,48 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
return x.mH
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_mH_normal(mode):
"""
Feature: mH
Description: Verify the result of mH
Expectation: success
"""
ms.set_context(mode=mode)
x = ms.Tensor(np.array([[0., 1.], [2., 3.]]), ms.float32)
net = Net()
output = net(x)
expect_output = np.array([[0., 2.],
[1., 3.]])
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,53 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
return x.mT
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_mT_normal(mode):
"""
Feature: mT
Description: Verify the result of mT
Expectation: success
"""
ms.set_context(mode=mode)
x = ms.Tensor(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), ms.float32)
net = Net()
output = net(x)
expect_output = np.array([[[1, 4],
[2, 5],
[3, 6]],
[[7, 10],
[8, 11],
[9, 12]]])
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,50 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x1, x2):
output = x1.mm(x2)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_mm_normal(mode):
"""
Feature: mm
Description: Verify the result of mm
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x1 = ms.Tensor(np.arange(6).reshape((2, 3)), dtype=ms.float32)
x2 = ms.Tensor(np.arange(12).reshape((3, 4)), dtype=ms.float32)
out = net(x1, x2)
expect_out = np.array([[20, 23, 26, 29],
[56, 68, 80, 92]])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,50 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
output = x.msort()
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_msort_normal(mode):
"""
Feature: msort
Description: Verify the result of msort
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16)
out = net(x)
expect_out = np.array([[4., 2., 1.],
[5., 6., 3.],
[8., 9., 7.]])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,45 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x, nan, posinf, neginf):
output = x.nan_to_num(nan, posinf, neginf)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_nan_to_num_normal(mode):
"""
Feature: nan_to_num
Description: Verify the result of nan_to_num
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.array([float('nan'), float('inf'), -float('inf'), 3.14]), ms.float32)
out = net(x, 1.0, 2.0, 3.0)
expect_out = np.array([1., 2., 3., 3.14])
assert np.allclose(out.asnumpy(), expect_out)