forked from mindspore-Ecosystem/mindspore
!46729 api: slogdet tril
Merge pull request !46729 from 于振华/api_tril_slogdet_1209
This commit is contained in:
commit
5db6dbb8d4
|
@ -463,6 +463,7 @@ Array操作
|
||||||
mindspore.ops.shape
|
mindspore.ops.shape
|
||||||
mindspore.ops.size
|
mindspore.ops.size
|
||||||
mindspore.ops.slice
|
mindspore.ops.slice
|
||||||
|
mindspore.ops.slogdet
|
||||||
mindspore.ops.space_to_batch_nd
|
mindspore.ops.space_to_batch_nd
|
||||||
mindspore.ops.sparse_segment_mean
|
mindspore.ops.sparse_segment_mean
|
||||||
mindspore.ops.split
|
mindspore.ops.split
|
||||||
|
@ -479,6 +480,7 @@ Array操作
|
||||||
mindspore.ops.tensor_scatter_elements
|
mindspore.ops.tensor_scatter_elements
|
||||||
mindspore.ops.tensor_split
|
mindspore.ops.tensor_split
|
||||||
mindspore.ops.tile
|
mindspore.ops.tile
|
||||||
|
mindspore.ops.tril
|
||||||
mindspore.ops.top_k
|
mindspore.ops.top_k
|
||||||
mindspore.ops.transpose
|
mindspore.ops.transpose
|
||||||
mindspore.ops.unbind
|
mindspore.ops.unbind
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
mindspore.Tensor.slogdet
|
||||||
|
========================
|
||||||
|
|
||||||
|
.. py:method:: mindspore.Tensor.slogdet()
|
||||||
|
|
||||||
|
详情请参考 :func:`mindspore.ops.slogdet`。
|
|
@ -0,0 +1,6 @@
|
||||||
|
mindspore.Tensor.tril
|
||||||
|
=====================
|
||||||
|
|
||||||
|
.. py:method:: mindspore.Tensor.tril(diagonal=0)
|
||||||
|
|
||||||
|
详情请参考 :func:`mindspore.ops.tril`。
|
|
@ -245,6 +245,7 @@ mindspore.Tensor
|
||||||
mindspore.Tensor.sin
|
mindspore.Tensor.sin
|
||||||
mindspore.Tensor.sinc
|
mindspore.Tensor.sinc
|
||||||
mindspore.Tensor.size
|
mindspore.Tensor.size
|
||||||
|
mindspore.Tensor.slogdet
|
||||||
mindspore.Tensor.soft_shrink
|
mindspore.Tensor.soft_shrink
|
||||||
mindspore.Tensor.split
|
mindspore.Tensor.split
|
||||||
mindspore.Tensor.sqrt
|
mindspore.Tensor.sqrt
|
||||||
|
@ -272,6 +273,7 @@ mindspore.Tensor
|
||||||
mindspore.Tensor.top_k
|
mindspore.Tensor.top_k
|
||||||
mindspore.Tensor.trace
|
mindspore.Tensor.trace
|
||||||
mindspore.Tensor.transpose
|
mindspore.Tensor.transpose
|
||||||
|
mindspore.Tensor.tril
|
||||||
mindspore.Tensor.triu
|
mindspore.Tensor.triu
|
||||||
mindspore.Tensor.true_divide
|
mindspore.Tensor.true_divide
|
||||||
mindspore.Tensor.unbind
|
mindspore.Tensor.unbind
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
mindspore.ops.slogdet
|
||||||
|
=====================
|
||||||
|
|
||||||
|
.. py:function:: mindspore.ops.slogdet(x)
|
||||||
|
|
||||||
|
对一个或多个方阵行列式的绝对值取对数,返回其符号和值。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **x** (Tensor) - 输入Tensor,shape为 :math:`[..., M, M]` 。矩阵必须至少有两个维度,最后两个维度尺寸必须相同。支持的数据类型为float32、float64、complex64或complex128。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Tensor,行列式的绝对值的对数的符号,shape为 `x.shape[:-2]` ,数据类型与 `x` 相同。
|
||||||
|
|
||||||
|
Tensor,行列式的绝对值的对数,shape为 `x.shape[:-2]` ,数据类型与 `x` 相同。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - `x` 不为 Tensor。
|
||||||
|
- **TypeError** - `x` 的数据类型不为以下类型:float32、 float64、 complex64 和 complex128。
|
||||||
|
- **ValueError** - `x` 的最后两个维度大小不同。
|
||||||
|
- **ValueError** - `x` 的维数小于2。
|
|
@ -0,0 +1,21 @@
|
||||||
|
mindspore.ops.tril
|
||||||
|
===================
|
||||||
|
|
||||||
|
.. py:function:: mindspore.ops.tril(input_x, diagonal=0)
|
||||||
|
|
||||||
|
返回单个矩阵(二维Tensor)或批次输入矩阵的下三角形部分,其他位置的元素将被置零。
|
||||||
|
矩阵的下三角形部分定义为对角线本身和对角线以下的元素。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **input_x** (Tensor) - 输入Tensor。shape为 :math:`(x_1, x_2, ..., x_R)` ,其rank至少为2。
|
||||||
|
支持的数据类型有包括所有数值型和bool类型。
|
||||||
|
- **diagonal** (int,可选) - 指定对角线位置,默认值:0,指定主对角线。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Tensor,其数据类型和shape维度与 `input_x` 相同。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - 如果 `input_x` 不是Tensor。
|
||||||
|
- **TypeError** - 如果 `diagonal` 不是int类型。
|
||||||
|
- **TypeError** - 如果 `input_x` 的数据类型既不是数值型也不是bool。
|
||||||
|
- **ValueError** - 如果 `input_x` 的秩小于2。
|
|
@ -172,7 +172,9 @@ BuiltInTypeMap &GetMethodMap() {
|
||||||
{"atan2", std::string("atan2")}, // P.Atan2
|
{"atan2", std::string("atan2")}, // P.Atan2
|
||||||
{"angle", std::string("angle")}, // C.reduce_any
|
{"angle", std::string("angle")}, // C.reduce_any
|
||||||
{"any", std::string("any_")}, // C.reduce_any
|
{"any", std::string("any_")}, // C.reduce_any
|
||||||
{"bincount", std::string("bincount")}, // C.reduce_any
|
{"bincount", std::string("bincount")}, // bincount
|
||||||
|
{"slogdet", std::string("slogdet")}, // slogdet
|
||||||
|
{"tril", std::string("tril")}, // tril
|
||||||
{"__add__", std::string("add")}, // C.add
|
{"__add__", std::string("add")}, // C.add
|
||||||
{"__sub__", std::string("sub")}, // C.sub
|
{"__sub__", std::string("sub")}, // C.sub
|
||||||
{"__mul__", std::string("mul")}, // C.mul
|
{"__mul__", std::string("mul")}, // C.mul
|
||||||
|
|
|
@ -292,6 +292,20 @@ def strides_(x):
|
||||||
return strides
|
return strides
|
||||||
|
|
||||||
|
|
||||||
|
def slogdet(x):
|
||||||
|
r"""
|
||||||
|
For details, please refer to :func:`mindspore.ops.slogdet`.
|
||||||
|
"""
|
||||||
|
return F.slogdet(x)
|
||||||
|
|
||||||
|
|
||||||
|
def tril(x, diagonal=0):
|
||||||
|
r"""
|
||||||
|
For details, please refer to :func:`mindspore.ops.tril`.
|
||||||
|
"""
|
||||||
|
return F.tril(x, diagonal)
|
||||||
|
|
||||||
|
|
||||||
def hasattr(x, attr): # pylint: disable=redefined-builtin
|
def hasattr(x, attr): # pylint: disable=redefined-builtin
|
||||||
"""
|
"""
|
||||||
Return whether an object has the attribute.
|
Return whether an object has the attribute.
|
||||||
|
|
|
@ -1987,6 +1987,20 @@ class Tensor(Tensor_):
|
||||||
new_shape = validator.prepare_shape_for_squeeze(self.shape, axis)
|
new_shape = validator.prepare_shape_for_squeeze(self.shape, axis)
|
||||||
return tensor_operator_registry.get('reshape')()(self, new_shape)
|
return tensor_operator_registry.get('reshape')()(self, new_shape)
|
||||||
|
|
||||||
|
def slogdet(self):
|
||||||
|
"""
|
||||||
|
For details, please refer to :func:`mindspore.ops.slogdet`.
|
||||||
|
"""
|
||||||
|
self._init_check()
|
||||||
|
return tensor_operator_registry.get('slogdet')(self)
|
||||||
|
|
||||||
|
def tril(self, diagonal=0):
|
||||||
|
"""
|
||||||
|
For details, please refer to :func:`mindspore.ops.tril`.
|
||||||
|
"""
|
||||||
|
self._init_check()
|
||||||
|
return tensor_operator_registry.get('tril')(self, diagonal)
|
||||||
|
|
||||||
def unsqueeze(self, dim):
|
def unsqueeze(self, dim):
|
||||||
"""
|
"""
|
||||||
For details, please refer to :func:`mindspore.ops.unsqueeze`.
|
For details, please refer to :func:`mindspore.ops.unsqueeze`.
|
||||||
|
|
|
@ -73,6 +73,7 @@ from .array_func import (
|
||||||
scatter_nd_div,
|
scatter_nd_div,
|
||||||
scatter_nd_max,
|
scatter_nd_max,
|
||||||
scatter_nd_min,
|
scatter_nd_min,
|
||||||
|
tril,
|
||||||
gather,
|
gather,
|
||||||
gather_d,
|
gather_d,
|
||||||
gather_elements,
|
gather_elements,
|
||||||
|
@ -218,6 +219,7 @@ from .math_func import (
|
||||||
log,
|
log,
|
||||||
logdet,
|
logdet,
|
||||||
log_matrix_determinant,
|
log_matrix_determinant,
|
||||||
|
slogdet,
|
||||||
matrix_determinant,
|
matrix_determinant,
|
||||||
linspace,
|
linspace,
|
||||||
matrix_solve,
|
matrix_solve,
|
||||||
|
|
|
@ -47,6 +47,7 @@ from mindspore.ops.operations.array_ops import (
|
||||||
Lstsq,
|
Lstsq,
|
||||||
Mvlgamma,
|
Mvlgamma,
|
||||||
CountNonZero,
|
CountNonZero,
|
||||||
|
Tril
|
||||||
)
|
)
|
||||||
from mindspore.ops.operations.array_ops import TensorScatterElements
|
from mindspore.ops.operations.array_ops import TensorScatterElements
|
||||||
from mindspore.common import Tensor
|
from mindspore.common import Tensor
|
||||||
|
@ -4687,6 +4688,66 @@ def split(x, split_size_or_sections, axis=0):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def tril(input_x, diagonal):
|
||||||
|
"""
|
||||||
|
Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices input,
|
||||||
|
the other elements of the result tensor out are set to 0.
|
||||||
|
The lower triangular part of the matrix is defined as the elements on and below the diagonal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_x (Tensor): A Tensor with shape :math:`(x_1, x_2, ..., x_R)`. The rank must be at least 2.
|
||||||
|
Supporting all number types including bool.
|
||||||
|
diagonal (int, optional): An optional attribute indicates the diagonal to consider, default: 0,
|
||||||
|
indicating the main diagonal.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, the same shape and data type as the input `x`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` is not a Tensor.
|
||||||
|
TypeError: If `diagonal` is not an int.
|
||||||
|
TypeError: If the type of `x` is neither number nor bool.
|
||||||
|
ValueError: If the rank of `x` is less than 2.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
|
||||||
|
... [ 5, 6, 7, 8],
|
||||||
|
... [10, 11, 12, 13],
|
||||||
|
... [14, 15, 16, 17]]))
|
||||||
|
>>> result = ops.tril(x)
|
||||||
|
>>> print(result)
|
||||||
|
[[ 1 0 0 0]
|
||||||
|
[ 5 6 0 0]
|
||||||
|
[10 11 12 0]
|
||||||
|
[14 15 16 17]]
|
||||||
|
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
|
||||||
|
... [ 5, 6, 7, 8],
|
||||||
|
... [10, 11, 12, 13],
|
||||||
|
... [14, 15, 16, 17]]))
|
||||||
|
>>> result = ops.tril(x, diagonal=1)
|
||||||
|
>>> print(result)
|
||||||
|
[[ 1 2 0 0]
|
||||||
|
[ 5 6 7 0]
|
||||||
|
[10 11 12 13]
|
||||||
|
[14 15 16 17]]
|
||||||
|
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
|
||||||
|
... [ 5, 6, 7, 8],
|
||||||
|
... [10, 11, 12, 13],
|
||||||
|
... [14, 15, 16, 17]]))
|
||||||
|
>>> result = ops.tril(x, diagonal=-1)
|
||||||
|
>>> print(result)
|
||||||
|
[[ 0 0 0 0]
|
||||||
|
[ 5 0 0 0]
|
||||||
|
[10 11 0 0]
|
||||||
|
[14 15 16 0]]
|
||||||
|
"""
|
||||||
|
tril_ = Tril(diagonal)
|
||||||
|
return tril_(input_x)
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _canonicalize_axis(axis, ndim):
|
def _canonicalize_axis(axis, ndim):
|
||||||
"""
|
"""
|
||||||
|
@ -5792,6 +5853,7 @@ __all__ = [
|
||||||
'scatter_div',
|
'scatter_div',
|
||||||
'scatter_update',
|
'scatter_update',
|
||||||
'select',
|
'select',
|
||||||
|
'tril',
|
||||||
'nonzero',
|
'nonzero',
|
||||||
'matrix_diag',
|
'matrix_diag',
|
||||||
'matrix_diag_part',
|
'matrix_diag_part',
|
||||||
|
|
|
@ -3175,6 +3175,42 @@ def matrix_solve(matrix, rhs, adjoint=False): # pylint: disable=redefined-outer
|
||||||
return matrix_solve_(matrix, rhs)
|
return matrix_solve_(matrix, rhs)
|
||||||
|
|
||||||
|
|
||||||
|
def slogdet(x):
|
||||||
|
r"""
|
||||||
|
Computes the sign and the log of the absolute value of the determinant of one or more square matrices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): A matrix to be calculated, its shape is :math:`[..., M, M]`.
|
||||||
|
The matrix must be at least two dimensions, and the last two
|
||||||
|
dimensions must be the same size. Data type must be float32, float64, complex64 or complex128.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor. The signs of the log determinants. The shape is :math:`x.shape[:-2]`,
|
||||||
|
and the dtype is same as `x`.
|
||||||
|
|
||||||
|
Tensor. The absolute values of the log determinants. The shape is :math:`x.shape[:-2]`, and
|
||||||
|
the dtype is same as `x`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` is not a Tensor.
|
||||||
|
TypeError: If dtype of `x` not float32, float64, complex64 or complex128.
|
||||||
|
ValueError: If the last two dimensions of `x` is not same size.
|
||||||
|
ValueError: If the dimension of `x` is less than 2.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> input_x = Tensor(np.array([[[-4.5, -1.5], [7.0, 6.0]], [[2.5, 0.5], [3.0, 9.0]]]), mindspore.float32)
|
||||||
|
>>> sign, output = ops.slogdet(input_x)
|
||||||
|
>>> print(sign)
|
||||||
|
[-1. 1.]
|
||||||
|
>>> print(output)
|
||||||
|
[2.80336046e+00 3.04452229e+00]
|
||||||
|
"""
|
||||||
|
return log_matrix_determinant_(x)
|
||||||
|
|
||||||
|
|
||||||
def truncate_div(x, y):
|
def truncate_div(x, y):
|
||||||
"""
|
"""
|
||||||
Divides the first input tensor by the second input tensor element-wise for integer types, negative numbers will
|
Divides the first input tensor by the second input tensor element-wise for integer types, negative numbers will
|
||||||
|
@ -9751,6 +9787,7 @@ __all__ = [
|
||||||
'chain_matmul',
|
'chain_matmul',
|
||||||
'hann_window',
|
'hann_window',
|
||||||
'log2',
|
'log2',
|
||||||
|
'slogdet',
|
||||||
'xlogy',
|
'xlogy',
|
||||||
'log10',
|
'log10',
|
||||||
'log1p',
|
'log1p',
|
||||||
|
|
|
@ -133,6 +133,8 @@ tensor_operator_registry.register('real', real)
|
||||||
tensor_operator_registry.register('reciprocal', reciprocal)
|
tensor_operator_registry.register('reciprocal', reciprocal)
|
||||||
tensor_operator_registry.register('rsqrt', rsqrt)
|
tensor_operator_registry.register('rsqrt', rsqrt)
|
||||||
tensor_operator_registry.register('bincount', bincount)
|
tensor_operator_registry.register('bincount', bincount)
|
||||||
|
tensor_operator_registry.register('slogdet', slogdet)
|
||||||
|
tensor_operator_registry.register('tril', tril)
|
||||||
tensor_operator_registry.register('sqrt', sqrt)
|
tensor_operator_registry.register('sqrt', sqrt)
|
||||||
tensor_operator_registry.register('square', square)
|
tensor_operator_registry.register('square', square)
|
||||||
tensor_operator_registry.register('sub', sub)
|
tensor_operator_registry.register('sub', sub)
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import ops
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def construct(self, x):
|
||||||
|
return ops.slogdet(x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||||
|
def test_slogdet(mode):
|
||||||
|
"""
|
||||||
|
Feature: slogdet
|
||||||
|
Description: Verify the result of slogdet
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
ms.set_context(mode=mode)
|
||||||
|
x = Tensor([[-1.8297, -0.8474, 1.0292], [-1.2167, 0.5574, -0.6753], [-0.6702, 0.2276, 1.2421]])
|
||||||
|
net = Net()
|
||||||
|
output1, output2 = net(x)
|
||||||
|
expect_output1 = np.array(-1, dtype=np.float32)
|
||||||
|
expect_output2 = np.array(1.13549, dtype=np.float32)
|
||||||
|
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||||
|
assert np.allclose(output2.asnumpy(), expect_output2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||||
|
def test_slogdet_complex(mode):
|
||||||
|
"""
|
||||||
|
Feature: slogdet
|
||||||
|
Description: Verify the result of slogdet
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
ms.set_context(mode=mode)
|
||||||
|
x = Tensor([[-1.5 + 7.8j, 3 + 5.75j, 2 + 2.4j],
|
||||||
|
[-6.4 + 485.4j, 45 + 3.14j, 45 + 453j],
|
||||||
|
[-3.5 + 5.8j, 63 + 12.75j, -5 + 6.4j]], dtype=ms.complex64)
|
||||||
|
net = Net()
|
||||||
|
output1, output2 = net(x)
|
||||||
|
expect_output1 = np.array(0.749919+0.66153j, dtype=np.complex)
|
||||||
|
expect_output2 = np.array(12.0614+0j, dtype=np.complex)
|
||||||
|
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||||
|
assert np.allclose(output2.asnumpy(), expect_output2)
|
|
@ -0,0 +1,47 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import ops
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def construct(self, x, diagonal=0):
|
||||||
|
return ops.tril(x, diagonal)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||||
|
def test_tril(mode):
|
||||||
|
"""
|
||||||
|
Feature: tril
|
||||||
|
Description: Verify the result of tril
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
ms.set_context(mode=mode)
|
||||||
|
x = Tensor([[-1.8297, -0.8474, 1.0292], [-1.2167, 0.5574, -0.6753], [-0.6702, 0.2276, 1.2421]])
|
||||||
|
net = Net()
|
||||||
|
output = net(x)
|
||||||
|
expect_output = np.array([[-1.8297, 0., 0.], [-1.2167, 0.5574, 0.], [-0.6702, 0.2276, 1.2421]], dtype=np.float32)
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output)
|
|
@ -0,0 +1,72 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def construct(self, x):
|
||||||
|
return x.slogdet()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||||
|
def test_slogdet(mode):
|
||||||
|
"""
|
||||||
|
Feature: slogdet
|
||||||
|
Description: Verify the result of slogdet
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
ms.set_context(mode=mode)
|
||||||
|
x = Tensor([[-1.8297, -0.8474, 1.0292], [-1.2167, 0.5574, -0.6753], [-0.6702, 0.2276, 1.2421]])
|
||||||
|
net = Net()
|
||||||
|
output1, output2 = net(x)
|
||||||
|
expect_output1 = np.array(-1, dtype=np.float32)
|
||||||
|
expect_output2 = np.array(1.13549, dtype=np.float32)
|
||||||
|
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||||
|
assert np.allclose(output2.asnumpy(), expect_output2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||||
|
def test_slogdet_complex(mode):
|
||||||
|
"""
|
||||||
|
Feature: slogdet
|
||||||
|
Description: Verify the result of slogdet
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
ms.set_context(mode=mode)
|
||||||
|
x = Tensor([[-1.5 + 7.8j, 3 + 5.75j, 2 + 2.4j],
|
||||||
|
[-6.4 + 485.4j, 45 + 3.14j, 45 + 453j],
|
||||||
|
[-3.5 + 5.8j, 63 + 12.75j, -5 + 6.4j]], dtype=ms.complex64)
|
||||||
|
net = Net()
|
||||||
|
output1, output2 = net(x)
|
||||||
|
expect_output1 = np.array(0.749919+0.66153j, dtype=np.complex)
|
||||||
|
expect_output2 = np.array(12.0614+0j, dtype=np.complex)
|
||||||
|
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||||
|
assert np.allclose(output2.asnumpy(), expect_output2)
|
|
@ -0,0 +1,46 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def construct(self, x, diagonal=0):
|
||||||
|
return x.tril(diagonal)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||||
|
def test_tril(mode):
|
||||||
|
"""
|
||||||
|
Feature: tril
|
||||||
|
Description: Verify the result of tril
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
ms.set_context(mode=mode)
|
||||||
|
x = Tensor([[-1.8297, -0.8474, 1.0292], [-1.2167, 0.5574, -0.6753], [-0.6702, 0.2276, 1.2421]])
|
||||||
|
net = Net()
|
||||||
|
output = net(x)
|
||||||
|
expect_output = np.array([[-1.8297, 0., 0.], [-1.2167, 0.5574, 0.], [-0.6702, 0.2276, 1.2421]], dtype=np.float32)
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output)
|
Loading…
Reference in New Issue