support mH mm msort mT NanToNum
This commit is contained in:
parent
20d1c98691
commit
4130c29364
|
@ -304,6 +304,7 @@ Reduction函数
|
|||
mindspore.ops.dot
|
||||
mindspore.ops.matmul
|
||||
mindspore.ops.matrix_solve
|
||||
mindspore.ops.mm
|
||||
mindspore.ops.ger
|
||||
mindspore.ops.renorm
|
||||
mindspore.ops.tensor_dot
|
||||
|
@ -389,6 +390,8 @@ Array操作
|
|||
mindspore.ops.matrix_diag_part
|
||||
mindspore.ops.matrix_set_diag
|
||||
mindspore.ops.meshgrid
|
||||
mindspore.ops.msort
|
||||
mindspore.ops.nan_to_num
|
||||
mindspore.ops.normal
|
||||
mindspore.ops.nonzero
|
||||
mindspore.ops.numel
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
mindspore.Tensor.mH
|
||||
====================
|
||||
|
||||
.. py:method:: mindspore.Tensor.mH
|
||||
:property:
|
||||
|
||||
访问此属性等价于调用self.adjoint()方法。
|
||||
|
||||
详情请参考 :func:`mindspore.ops.adjoint`。
|
|
@ -0,0 +1,11 @@
|
|||
mindspore.Tensor.mT
|
||||
====================
|
||||
|
||||
.. py:method:: mindspore.Tensor.mT
|
||||
:property:
|
||||
|
||||
返回将最后两个维度交换的Tensor。
|
||||
|
||||
访问x.mT属性等价于调用x.swapaxes(-2, -1)方法。
|
||||
|
||||
详情请参考 :func:`mindspore.Tensor.swapaxes`。
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.mm
|
||||
====================
|
||||
|
||||
.. py:method:: mindspore.Tensor.mm(mat2)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.mm`。
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.msort
|
||||
=======================
|
||||
|
||||
.. py:method:: mindspore.Tensor.msort()
|
||||
|
||||
详情请参考 :func:`mindspore.ops.msort`。
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.nan_to_num
|
||||
============================
|
||||
|
||||
.. py:method:: mindspore.Tensor.nan_to_num(nan=0.0, posinf=None, neginf=None)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.nan_to_num`。
|
|
@ -159,9 +159,14 @@ mindspore.Tensor
|
|||
mindspore.Tensor.max
|
||||
mindspore.Tensor.mean
|
||||
mindspore.Tensor.median
|
||||
mindspore.Tensor.mH
|
||||
mindspore.Tensor.min
|
||||
mindspore.Tensor.minimum
|
||||
mindspore.Tensor.mm
|
||||
mindspore.Tensor.msort
|
||||
mindspore.Tensor.mT
|
||||
mindspore.Tensor.multiply
|
||||
mindspore.Tensor.nan_to_num
|
||||
mindspore.Tensor.narrow
|
||||
mindspore.Tensor.nbytes
|
||||
mindspore.Tensor.ndim
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
mindspore.ops.mm
|
||||
=================
|
||||
|
||||
.. py:function:: mindspore.ops.mm(mat1, mat2)
|
||||
|
||||
计算两个矩阵的乘积。
|
||||
|
||||
如果 `mat1` 是一个 :math:`(n \times m)` 的Tensor,`mat2` 是一个 :math:`(m \times p)` 的Tensor,`out` 则会是一个 :math:`(n \times p)` 的Tensor。
|
||||
|
||||
.. note::
|
||||
此函数不能支持广播。若需要可广播的方法,请参考 :func:`mindspore.ops.matmul`。
|
||||
|
||||
参数:
|
||||
- **mat1** (Tensor) - 矩阵相乘的第一个矩阵, `mat1` 的最后一维度必须和 `mat2` 的第一维度相等。
|
||||
- **mat2** (Tensor) - 矩阵相乘的第二个矩阵, `mat1` 的最后一维度必须和 `mat2` 的第一维度相等。
|
||||
|
||||
返回:
|
||||
Tensor或Scalar,输入的矩阵乘积。
|
||||
|
||||
异常:
|
||||
- **ValueError** - `mat1` 的最后一维度和 `mat2` 的倒数第二维度不相等。
|
||||
- **ValueError** - `mat1` 或者 `mat2` 不是一个矩阵。
|
|
@ -0,0 +1,17 @@
|
|||
mindspore.ops.msort
|
||||
====================
|
||||
|
||||
.. py:function:: mindspore.ops.msort(x)
|
||||
|
||||
将输入Tensor的元素沿其第一维按值升序排序。
|
||||
|
||||
ops.msort(t)相当于ops.Sort(axis=0)(t)[0]。另外可以参考 :class:`mindspore.ops.Sort()`。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 需要排序的输入,类型必须是float16或者float32。
|
||||
|
||||
返回:
|
||||
排序后的Tensor,与输入的shape和dtype一致。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 的类型既不是float16也不是float32。
|
|
@ -0,0 +1,19 @@
|
|||
mindspore.ops.nan_to_num
|
||||
=========================
|
||||
|
||||
.. py:function:: mindspore.ops.nan_to_num(x, nan=0.0, posinf=None, neginf=None)
|
||||
|
||||
将 `x` 中的`NaN`、正无穷大和负无穷大值分别替换为 `nan`, `posinf`, 和 `neginf` 指定的值。默认情况下,NaN替换为0,正无穷替换为 `x` 类型支持的上限,负无穷替换为由 `x` 类型支持的下限。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - shape为 :math:`(x_1, x_2, ..., x_R)` 的tensor。类型必须为float32或float16。
|
||||
- **nan** (float) - 替换 `NaN` 的值。默认值为0.0。
|
||||
- **posinf** (float) - 如果是一个数字,则为替换正无穷的值。如果为None,则将正无穷替换为 `x` 类型支持的上限。默认值为None。
|
||||
- **neginf** (float) - 如果是一个数字,则为替换负无穷的值。如果为None,则将负无穷替换为 `x` 类型支持的下限。默认值为None。
|
||||
|
||||
返回:
|
||||
Tensor,数据shape和类型与 `x` 相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 不是一个Tensor。
|
||||
- **TypeError** - `x` 的类型既不是float16也不是float32。
|
|
@ -165,9 +165,14 @@
|
|||
mindspore.Tensor.max
|
||||
mindspore.Tensor.mean
|
||||
mindspore.Tensor.median
|
||||
mindspore.Tensor.mH
|
||||
mindspore.Tensor.min
|
||||
mindspore.Tensor.minimum
|
||||
mindspore.Tensor.mm
|
||||
mindspore.Tensor.msort
|
||||
mindspore.Tensor.mT
|
||||
mindspore.Tensor.multiply
|
||||
mindspore.Tensor.nan_to_num
|
||||
mindspore.Tensor.narrow
|
||||
mindspore.Tensor.nbytes
|
||||
mindspore.Tensor.ndim
|
||||
|
|
|
@ -304,6 +304,7 @@ Linear Algebraic Functions
|
|||
mindspore.ops.dot
|
||||
mindspore.ops.matmul
|
||||
mindspore.ops.matrix_solve
|
||||
mindspore.ops.mm
|
||||
mindspore.ops.ger
|
||||
mindspore.ops.renorm
|
||||
mindspore.ops.tensor_dot
|
||||
|
@ -389,6 +390,8 @@ Array Operation
|
|||
mindspore.ops.matrix_diag_part
|
||||
mindspore.ops.matrix_set_diag
|
||||
mindspore.ops.meshgrid
|
||||
mindspore.ops.msort
|
||||
mindspore.ops.nan_to_num
|
||||
mindspore.ops.normal
|
||||
mindspore.ops.nonzero
|
||||
mindspore.ops.numel
|
||||
|
|
|
@ -398,8 +398,11 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"mvlgamma", std::string("mvlgamma")}, // mvlgamma()
|
||||
{"matmul", std::string("matmul")}, // matmul()
|
||||
{"maximum", std::string("maximum")}, // maximum()
|
||||
{"msort", std::string("msort")}, // msort()
|
||||
{"mm", std::string("mm")}, // mm()
|
||||
{"mul", std::string("mul")}, // mul()
|
||||
{"multiply", std::string("multiply")}, // multiply()
|
||||
{"nan_to_num", std::string("nan_to_num")}, // nan_to_num()
|
||||
{"neg", std::string("neg")}, // neg()
|
||||
{"ne", std::string("ne")}, // ne()
|
||||
{"sinh", std::string("sinh")}, // sinh()
|
||||
|
@ -461,6 +464,8 @@ BuiltInTypeMap &GetAttrMap() {
|
|||
{"itemsize", std::string("itemsize_")}, // C.itemsize_
|
||||
{"nbytes", std::string("nbytes_")}, // C.nbytes_
|
||||
{"strides", std::string("strides_")}, // C.strides_
|
||||
{"mH", std::string("adjoint")}, // C.adjoint
|
||||
{"mT", std::string("mT")}, // C.mT_
|
||||
}},
|
||||
{kObjectTypeRowTensorType,
|
||||
{
|
||||
|
|
|
@ -856,6 +856,35 @@ def median(x, global_median, axis=0, keep_dims=False):
|
|||
return median_(x)
|
||||
|
||||
|
||||
def msort(x):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.msort`.
|
||||
"""
|
||||
return F.msort(x)
|
||||
|
||||
|
||||
def mm(mat1, mat2):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.mm`.
|
||||
"""
|
||||
return F.mm(mat1, mat2)
|
||||
|
||||
|
||||
def mT(x):
|
||||
"""
|
||||
Returns a view of this tensor with the last two dimensions transposed.
|
||||
x.mT is equivalent to x.transpose(-2, -1).
|
||||
"""
|
||||
return swapaxes(x, -2, -1)
|
||||
|
||||
|
||||
def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.nan_to_num`.
|
||||
"""
|
||||
return F.nan_to_num(x, nan, posinf, neginf)
|
||||
|
||||
|
||||
def cumsum(x, axis=None, dtype=None):
|
||||
"""
|
||||
Returns the cumulative sum of the elements along a given axis.
|
||||
|
|
|
@ -4121,6 +4121,23 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('lstsq')(self, A)
|
||||
|
||||
@property
|
||||
def mH(self):
|
||||
r"""
|
||||
Accessing this property is equivalent to Calling self.adjoint().
|
||||
For details, please refer to :func:`mindspore.ops.adjoint`.
|
||||
"""
|
||||
return self.adjoint()
|
||||
|
||||
@property
|
||||
def mT(self):
|
||||
r"""
|
||||
Returns a view of this tensor with the last two dimensions transposed.
|
||||
x.mT is equivalent to x.swapaxes(-2, -1).
|
||||
For details, please refer to :func:`mindspore.Tensor.swapaxes`.
|
||||
"""
|
||||
return self.swapaxes(-2, -1)
|
||||
|
||||
def mvlgamma(self, p):
|
||||
r"""
|
||||
Computes the multivariate log-gamma function with dimension p element-wise.
|
||||
|
@ -4243,6 +4260,20 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('maximum')(self, other)
|
||||
|
||||
def mm(self, mat2):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.mm`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('mm')(self, mat2)
|
||||
|
||||
def msort(self):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.msort`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('msort')(self)
|
||||
|
||||
def mul(self, value):
|
||||
r"""
|
||||
Multiplies two tensors element-wise.
|
||||
|
@ -4282,6 +4313,12 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('mul')(self, value)
|
||||
|
||||
def nan_to_num(self, nan=0.0, posinf=None, neginf=None):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.nan_to_num`.
|
||||
"""
|
||||
return tensor_operator_registry.get('nan_to_num')(self, nan, posinf, neginf)
|
||||
|
||||
def neg(self):
|
||||
r"""
|
||||
Returns a tensor with negative values of the input tensor element-wise.
|
||||
|
|
|
@ -28,7 +28,7 @@ from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
|
|||
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like
|
||||
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from mindspore.ops.composite.random_ops import normal, laplace, uniform, gamma, poisson, multinomial
|
||||
from mindspore.ops.composite.math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul, cummin
|
||||
from mindspore.ops.composite.math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul, cummin, mm
|
||||
from mindspore.ops.composite.array_ops import repeat_interleave, repeat_elements, sequence_mask
|
||||
from mindspore.ops.composite.vmap_ops import _VmapGeneralPreprocess, _VmapGeneralRule
|
||||
from mindspore.ops.function.clip_func import clip_by_value
|
||||
|
@ -62,6 +62,7 @@ __all__ = [
|
|||
'repeat_interleave',
|
||||
'sequence_mask',
|
||||
'matmul',
|
||||
'mm',
|
||||
'_Grad',
|
||||
'_Vmap',
|
||||
'_VmapGeneralPreprocess']
|
||||
|
|
|
@ -731,6 +731,48 @@ def matmul(x1, x2, dtype=None):
|
|||
return res
|
||||
|
||||
|
||||
def mm(mat1, mat2):
|
||||
r"""
|
||||
Returns the matrix product of two arrays.
|
||||
If `mat1` is a :math:`(n \times m)` Tensor, `mat2` is a
|
||||
:math:`(m \times p)` Tensor, `out` will be a :math:`(n \times p)` Tensor.
|
||||
|
||||
Note:
|
||||
This function does not broadcast. For broadcasting matrix products, see :func:`mindspore.ops.matmul`.
|
||||
|
||||
Args:
|
||||
mat1 (Tensor): The first matrix to be matrix multiplied.
|
||||
The last dimension of `mat1` must be the same size as the first dimension of `mat2`.
|
||||
mat2 (Tensor): The second matrix to be matrix multiplied.
|
||||
The last dimension of `mat1` must be the same size as the first dimension of `mat2`.
|
||||
|
||||
Returns:
|
||||
Tensor or scalar, the matrix product of the inputs.
|
||||
|
||||
Raises:
|
||||
ValueError: If the last dimension of `mat1` is not the same size as the
|
||||
second-to-last dimension of `mat2`.
|
||||
ValueError: If `mat1` or `mat2` is not a matrix.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> import mindspore.ops as ops
|
||||
>>> import numpy as np
|
||||
>>> x1 = ms.Tensor(np.random.rand(2, 3))
|
||||
>>> x2 = ms.Tensor(np.random.rand(3, 4))
|
||||
>>> out = ops.mm(x1, x2)
|
||||
>>> print(out.shape)
|
||||
(2, 4)
|
||||
"""
|
||||
if mat1.ndim != 2 or mat2.ndim != 2:
|
||||
raise ValueError(f"For mm, the input tensor must be a matrix, "
|
||||
f"but got mat1.ndim:{mat1.ndim}, mat2.ndim:{mat2.ndim}")
|
||||
return matmul(mat1, mat2)
|
||||
|
||||
|
||||
def cummin(x, axis):
|
||||
r"""
|
||||
Returns a tuple (values,indices) where 'values' is the cumulative minimum value of input Tensor `x`
|
||||
|
|
|
@ -198,6 +198,7 @@ from .math_func import (
|
|||
matrix_solve,
|
||||
maximum,
|
||||
median,
|
||||
nan_to_num,
|
||||
logaddexp,
|
||||
logaddexp2,
|
||||
logit,
|
||||
|
@ -389,6 +390,7 @@ from .nn_func import (
|
|||
lp_pool1d,
|
||||
lp_pool2d,
|
||||
mse_loss,
|
||||
msort
|
||||
)
|
||||
from .linalg_func import (
|
||||
svd,
|
||||
|
|
|
@ -7346,6 +7346,7 @@ __all__ = [
|
|||
'tensor_mul',
|
||||
'mul',
|
||||
'multiply',
|
||||
'nan_to_num',
|
||||
'tensor_div',
|
||||
'div',
|
||||
'divide',
|
||||
|
|
|
@ -4965,6 +4965,38 @@ def mse_loss(input_x, target, reduction='mean'):
|
|||
return _get_cache_prim(P.Cast)()(x, input_dtype)
|
||||
|
||||
|
||||
def msort(x):
|
||||
r"""
|
||||
Sorts the elements of the input tensor along its first dimension in ascending order by value.
|
||||
|
||||
ops.msort(t) is equivalent to ops.Sort(axis=0)(t)[0]. See also :class:`mindspore.ops.Sort()`.
|
||||
|
||||
Args:
|
||||
x (Tensor): The input to sort, with float16 or float32 data type.
|
||||
|
||||
Returns:
|
||||
A tensor whose values are the sorted values, with the same shape and data type as input.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> import mindspore.ops as ops
|
||||
>>> import numpy as np
|
||||
>>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16)
|
||||
>>> output = ops.msort(x)
|
||||
>>> print(output)
|
||||
[[4. 2. 1.]
|
||||
[5. 6. 3.]
|
||||
[8. 9. 7.]]
|
||||
"""
|
||||
return ops.Sort(axis=0)(x)[0]
|
||||
|
||||
|
||||
__all__ = [
|
||||
'adaptive_avg_pool1d',
|
||||
'adaptive_avg_pool2d',
|
||||
|
@ -5037,5 +5069,6 @@ __all__ = [
|
|||
'max_unpool2d',
|
||||
'max_unpool3d',
|
||||
'mse_loss',
|
||||
'msort',
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -30,6 +30,7 @@ from mindspore.ops.operations.array_ops import UniqueConsecutive, Triu
|
|||
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
|
||||
from mindspore.ops.operations._inner_ops import Roll
|
||||
from mindspore.ops.composite.array_ops import repeat_interleave
|
||||
from mindspore.ops.composite.math_ops import mm
|
||||
|
||||
typeof = Primitive('typeof')
|
||||
hastype = Primitive('hastype')
|
||||
|
@ -288,6 +289,9 @@ tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)
|
|||
tensor_operator_registry.register('csr_to_dense', csr_to_dense)
|
||||
tensor_operator_registry.register('narrow', narrow)
|
||||
tensor_operator_registry.register('sort', sort)
|
||||
tensor_operator_registry.register('msort', msort)
|
||||
tensor_operator_registry.register('mm', mm)
|
||||
tensor_operator_registry.register('nan_to_num', nan_to_num)
|
||||
tensor_operator_registry.register('csr_to_coo', csr_to_coo)
|
||||
tensor_operator_registry.register('zeros', zeros)
|
||||
tensor_operator_registry.register('unsorted_segment_min', unsorted_segment_min)
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x1, x2):
|
||||
output = ops.mm(x1, x2)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_mm_normal(mode):
|
||||
"""
|
||||
Feature: mm
|
||||
Description: Verify the result of mm
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x1 = ms.Tensor(np.arange(6).reshape((2, 3)), dtype=ms.float32)
|
||||
x2 = ms.Tensor(np.arange(12).reshape((3, 4)), dtype=ms.float32)
|
||||
out = net(x1, x2)
|
||||
expect_out = np.array([[20, 23, 26, 29],
|
||||
[56, 68, 80, 92]])
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
output = ops.msort(x)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_msort_normal(mode):
|
||||
"""
|
||||
Feature: msort
|
||||
Description: Verify the result of msort
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16)
|
||||
out = net(x)
|
||||
expect_out = np.array([[4., 2., 1.],
|
||||
[5., 6., 3.],
|
||||
[8., 9., 7.]])
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, nan, posinf, neginf):
|
||||
output = ops.nan_to_num(x, nan, posinf, neginf)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_nan_to_num_normal(mode):
|
||||
"""
|
||||
Feature: nan_to_num
|
||||
Description: Verify the result of nan_to_num
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor(np.array([float('nan'), float('inf'), -float('inf'), 3.14]), ms.float32)
|
||||
out = net(x, 1.0, 2.0, 3.0)
|
||||
expect_out = np.array([1., 2., 3., 3.14])
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x.mH
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_mH_normal(mode):
|
||||
"""
|
||||
Feature: mH
|
||||
Description: Verify the result of mH
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = ms.Tensor(np.array([[0., 1.], [2., 3.]]), ms.float32)
|
||||
net = Net()
|
||||
output = net(x)
|
||||
expect_output = np.array([[0., 2.],
|
||||
[1., 3.]])
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x.mT
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_mT_normal(mode):
|
||||
"""
|
||||
Feature: mT
|
||||
Description: Verify the result of mT
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = ms.Tensor(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), ms.float32)
|
||||
net = Net()
|
||||
output = net(x)
|
||||
expect_output = np.array([[[1, 4],
|
||||
[2, 5],
|
||||
[3, 6]],
|
||||
|
||||
[[7, 10],
|
||||
[8, 11],
|
||||
[9, 12]]])
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x1, x2):
|
||||
output = x1.mm(x2)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_mm_normal(mode):
|
||||
"""
|
||||
Feature: mm
|
||||
Description: Verify the result of mm
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x1 = ms.Tensor(np.arange(6).reshape((2, 3)), dtype=ms.float32)
|
||||
x2 = ms.Tensor(np.arange(12).reshape((3, 4)), dtype=ms.float32)
|
||||
out = net(x1, x2)
|
||||
expect_out = np.array([[20, 23, 26, 29],
|
||||
[56, 68, 80, 92]])
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
output = x.msort()
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_msort_normal(mode):
|
||||
"""
|
||||
Feature: msort
|
||||
Description: Verify the result of msort
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16)
|
||||
out = net(x)
|
||||
expect_out = np.array([[4., 2., 1.],
|
||||
[5., 6., 3.],
|
||||
[8., 9., 7.]])
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, nan, posinf, neginf):
|
||||
output = x.nan_to_num(nan, posinf, neginf)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_nan_to_num_normal(mode):
|
||||
"""
|
||||
Feature: nan_to_num
|
||||
Description: Verify the result of nan_to_num
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor(np.array([float('nan'), float('inf'), -float('inf'), 3.14]), ms.float32)
|
||||
out = net(x, 1.0, 2.0, 3.0)
|
||||
expect_out = np.array([1., 2., 3., 3.14])
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
Loading…
Reference in New Issue