forked from mindspore-Ecosystem/mindspore
!46729 api: slogdet tril
Merge pull request !46729 from 于振华/api_tril_slogdet_1209
This commit is contained in:
commit
5db6dbb8d4
|
@ -463,6 +463,7 @@ Array操作
|
|||
mindspore.ops.shape
|
||||
mindspore.ops.size
|
||||
mindspore.ops.slice
|
||||
mindspore.ops.slogdet
|
||||
mindspore.ops.space_to_batch_nd
|
||||
mindspore.ops.sparse_segment_mean
|
||||
mindspore.ops.split
|
||||
|
@ -479,6 +480,7 @@ Array操作
|
|||
mindspore.ops.tensor_scatter_elements
|
||||
mindspore.ops.tensor_split
|
||||
mindspore.ops.tile
|
||||
mindspore.ops.tril
|
||||
mindspore.ops.top_k
|
||||
mindspore.ops.transpose
|
||||
mindspore.ops.unbind
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.slogdet
|
||||
========================
|
||||
|
||||
.. py:method:: mindspore.Tensor.slogdet()
|
||||
|
||||
详情请参考 :func:`mindspore.ops.slogdet`。
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.tril
|
||||
=====================
|
||||
|
||||
.. py:method:: mindspore.Tensor.tril(diagonal=0)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.tril`。
|
|
@ -245,6 +245,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.sin
|
||||
mindspore.Tensor.sinc
|
||||
mindspore.Tensor.size
|
||||
mindspore.Tensor.slogdet
|
||||
mindspore.Tensor.soft_shrink
|
||||
mindspore.Tensor.split
|
||||
mindspore.Tensor.sqrt
|
||||
|
@ -272,6 +273,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.top_k
|
||||
mindspore.Tensor.trace
|
||||
mindspore.Tensor.transpose
|
||||
mindspore.Tensor.tril
|
||||
mindspore.Tensor.triu
|
||||
mindspore.Tensor.true_divide
|
||||
mindspore.Tensor.unbind
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
mindspore.ops.slogdet
|
||||
=====================
|
||||
|
||||
.. py:function:: mindspore.ops.slogdet(x)
|
||||
|
||||
对一个或多个方阵行列式的绝对值取对数,返回其符号和值。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 输入Tensor,shape为 :math:`[..., M, M]` 。矩阵必须至少有两个维度,最后两个维度尺寸必须相同。支持的数据类型为float32、float64、complex64或complex128。
|
||||
|
||||
返回:
|
||||
Tensor,行列式的绝对值的对数的符号,shape为 `x.shape[:-2]` ,数据类型与 `x` 相同。
|
||||
|
||||
Tensor,行列式的绝对值的对数,shape为 `x.shape[:-2]` ,数据类型与 `x` 相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 不为 Tensor。
|
||||
- **TypeError** - `x` 的数据类型不为以下类型:float32、 float64、 complex64 和 complex128。
|
||||
- **ValueError** - `x` 的最后两个维度大小不同。
|
||||
- **ValueError** - `x` 的维数小于2。
|
|
@ -0,0 +1,21 @@
|
|||
mindspore.ops.tril
|
||||
===================
|
||||
|
||||
.. py:function:: mindspore.ops.tril(input_x, diagonal=0)
|
||||
|
||||
返回单个矩阵(二维Tensor)或批次输入矩阵的下三角形部分,其他位置的元素将被置零。
|
||||
矩阵的下三角形部分定义为对角线本身和对角线以下的元素。
|
||||
|
||||
参数:
|
||||
- **input_x** (Tensor) - 输入Tensor。shape为 :math:`(x_1, x_2, ..., x_R)` ,其rank至少为2。
|
||||
支持的数据类型有包括所有数值型和bool类型。
|
||||
- **diagonal** (int,可选) - 指定对角线位置,默认值:0,指定主对角线。
|
||||
|
||||
返回:
|
||||
Tensor,其数据类型和shape维度与 `input_x` 相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `input_x` 不是Tensor。
|
||||
- **TypeError** - 如果 `diagonal` 不是int类型。
|
||||
- **TypeError** - 如果 `input_x` 的数据类型既不是数值型也不是bool。
|
||||
- **ValueError** - 如果 `input_x` 的秩小于2。
|
|
@ -172,7 +172,9 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"atan2", std::string("atan2")}, // P.Atan2
|
||||
{"angle", std::string("angle")}, // C.reduce_any
|
||||
{"any", std::string("any_")}, // C.reduce_any
|
||||
{"bincount", std::string("bincount")}, // C.reduce_any
|
||||
{"bincount", std::string("bincount")}, // bincount
|
||||
{"slogdet", std::string("slogdet")}, // slogdet
|
||||
{"tril", std::string("tril")}, // tril
|
||||
{"__add__", std::string("add")}, // C.add
|
||||
{"__sub__", std::string("sub")}, // C.sub
|
||||
{"__mul__", std::string("mul")}, // C.mul
|
||||
|
|
|
@ -292,6 +292,20 @@ def strides_(x):
|
|||
return strides
|
||||
|
||||
|
||||
def slogdet(x):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.slogdet`.
|
||||
"""
|
||||
return F.slogdet(x)
|
||||
|
||||
|
||||
def tril(x, diagonal=0):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.tril`.
|
||||
"""
|
||||
return F.tril(x, diagonal)
|
||||
|
||||
|
||||
def hasattr(x, attr): # pylint: disable=redefined-builtin
|
||||
"""
|
||||
Return whether an object has the attribute.
|
||||
|
|
|
@ -1987,6 +1987,20 @@ class Tensor(Tensor_):
|
|||
new_shape = validator.prepare_shape_for_squeeze(self.shape, axis)
|
||||
return tensor_operator_registry.get('reshape')()(self, new_shape)
|
||||
|
||||
def slogdet(self):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.slogdet`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('slogdet')(self)
|
||||
|
||||
def tril(self, diagonal=0):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.tril`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('tril')(self, diagonal)
|
||||
|
||||
def unsqueeze(self, dim):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.unsqueeze`.
|
||||
|
|
|
@ -73,6 +73,7 @@ from .array_func import (
|
|||
scatter_nd_div,
|
||||
scatter_nd_max,
|
||||
scatter_nd_min,
|
||||
tril,
|
||||
gather,
|
||||
gather_d,
|
||||
gather_elements,
|
||||
|
@ -218,6 +219,7 @@ from .math_func import (
|
|||
log,
|
||||
logdet,
|
||||
log_matrix_determinant,
|
||||
slogdet,
|
||||
matrix_determinant,
|
||||
linspace,
|
||||
matrix_solve,
|
||||
|
|
|
@ -47,6 +47,7 @@ from mindspore.ops.operations.array_ops import (
|
|||
Lstsq,
|
||||
Mvlgamma,
|
||||
CountNonZero,
|
||||
Tril
|
||||
)
|
||||
from mindspore.ops.operations.array_ops import TensorScatterElements
|
||||
from mindspore.common import Tensor
|
||||
|
@ -4687,6 +4688,66 @@ def split(x, split_size_or_sections, axis=0):
|
|||
return res
|
||||
|
||||
|
||||
def tril(input_x, diagonal):
|
||||
"""
|
||||
Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices input,
|
||||
the other elements of the result tensor out are set to 0.
|
||||
The lower triangular part of the matrix is defined as the elements on and below the diagonal.
|
||||
|
||||
Args:
|
||||
input_x (Tensor): A Tensor with shape :math:`(x_1, x_2, ..., x_R)`. The rank must be at least 2.
|
||||
Supporting all number types including bool.
|
||||
diagonal (int, optional): An optional attribute indicates the diagonal to consider, default: 0,
|
||||
indicating the main diagonal.
|
||||
|
||||
Returns:
|
||||
Tensor, the same shape and data type as the input `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If `diagonal` is not an int.
|
||||
TypeError: If the type of `x` is neither number nor bool.
|
||||
ValueError: If the rank of `x` is less than 2.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
|
||||
... [ 5, 6, 7, 8],
|
||||
... [10, 11, 12, 13],
|
||||
... [14, 15, 16, 17]]))
|
||||
>>> result = ops.tril(x)
|
||||
>>> print(result)
|
||||
[[ 1 0 0 0]
|
||||
[ 5 6 0 0]
|
||||
[10 11 12 0]
|
||||
[14 15 16 17]]
|
||||
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
|
||||
... [ 5, 6, 7, 8],
|
||||
... [10, 11, 12, 13],
|
||||
... [14, 15, 16, 17]]))
|
||||
>>> result = ops.tril(x, diagonal=1)
|
||||
>>> print(result)
|
||||
[[ 1 2 0 0]
|
||||
[ 5 6 7 0]
|
||||
[10 11 12 13]
|
||||
[14 15 16 17]]
|
||||
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
|
||||
... [ 5, 6, 7, 8],
|
||||
... [10, 11, 12, 13],
|
||||
... [14, 15, 16, 17]]))
|
||||
>>> result = ops.tril(x, diagonal=-1)
|
||||
>>> print(result)
|
||||
[[ 0 0 0 0]
|
||||
[ 5 0 0 0]
|
||||
[10 11 0 0]
|
||||
[14 15 16 0]]
|
||||
"""
|
||||
tril_ = Tril(diagonal)
|
||||
return tril_(input_x)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _canonicalize_axis(axis, ndim):
|
||||
"""
|
||||
|
@ -5792,6 +5853,7 @@ __all__ = [
|
|||
'scatter_div',
|
||||
'scatter_update',
|
||||
'select',
|
||||
'tril',
|
||||
'nonzero',
|
||||
'matrix_diag',
|
||||
'matrix_diag_part',
|
||||
|
|
|
@ -3175,6 +3175,42 @@ def matrix_solve(matrix, rhs, adjoint=False): # pylint: disable=redefined-outer
|
|||
return matrix_solve_(matrix, rhs)
|
||||
|
||||
|
||||
def slogdet(x):
|
||||
r"""
|
||||
Computes the sign and the log of the absolute value of the determinant of one or more square matrices.
|
||||
|
||||
Args:
|
||||
x (Tensor): A matrix to be calculated, its shape is :math:`[..., M, M]`.
|
||||
The matrix must be at least two dimensions, and the last two
|
||||
dimensions must be the same size. Data type must be float32, float64, complex64 or complex128.
|
||||
|
||||
Returns:
|
||||
Tensor. The signs of the log determinants. The shape is :math:`x.shape[:-2]`,
|
||||
and the dtype is same as `x`.
|
||||
|
||||
Tensor. The absolute values of the log determinants. The shape is :math:`x.shape[:-2]`, and
|
||||
the dtype is same as `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If dtype of `x` not float32, float64, complex64 or complex128.
|
||||
ValueError: If the last two dimensions of `x` is not same size.
|
||||
ValueError: If the dimension of `x` is less than 2.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([[[-4.5, -1.5], [7.0, 6.0]], [[2.5, 0.5], [3.0, 9.0]]]), mindspore.float32)
|
||||
>>> sign, output = ops.slogdet(input_x)
|
||||
>>> print(sign)
|
||||
[-1. 1.]
|
||||
>>> print(output)
|
||||
[2.80336046e+00 3.04452229e+00]
|
||||
"""
|
||||
return log_matrix_determinant_(x)
|
||||
|
||||
|
||||
def truncate_div(x, y):
|
||||
"""
|
||||
Divides the first input tensor by the second input tensor element-wise for integer types, negative numbers will
|
||||
|
@ -9751,6 +9787,7 @@ __all__ = [
|
|||
'chain_matmul',
|
||||
'hann_window',
|
||||
'log2',
|
||||
'slogdet',
|
||||
'xlogy',
|
||||
'log10',
|
||||
'log1p',
|
||||
|
|
|
@ -133,6 +133,8 @@ tensor_operator_registry.register('real', real)
|
|||
tensor_operator_registry.register('reciprocal', reciprocal)
|
||||
tensor_operator_registry.register('rsqrt', rsqrt)
|
||||
tensor_operator_registry.register('bincount', bincount)
|
||||
tensor_operator_registry.register('slogdet', slogdet)
|
||||
tensor_operator_registry.register('tril', tril)
|
||||
tensor_operator_registry.register('sqrt', sqrt)
|
||||
tensor_operator_registry.register('square', square)
|
||||
tensor_operator_registry.register('sub', sub)
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return ops.slogdet(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_slogdet(mode):
|
||||
"""
|
||||
Feature: slogdet
|
||||
Description: Verify the result of slogdet
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor([[-1.8297, -0.8474, 1.0292], [-1.2167, 0.5574, -0.6753], [-0.6702, 0.2276, 1.2421]])
|
||||
net = Net()
|
||||
output1, output2 = net(x)
|
||||
expect_output1 = np.array(-1, dtype=np.float32)
|
||||
expect_output2 = np.array(1.13549, dtype=np.float32)
|
||||
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||
assert np.allclose(output2.asnumpy(), expect_output2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_slogdet_complex(mode):
|
||||
"""
|
||||
Feature: slogdet
|
||||
Description: Verify the result of slogdet
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor([[-1.5 + 7.8j, 3 + 5.75j, 2 + 2.4j],
|
||||
[-6.4 + 485.4j, 45 + 3.14j, 45 + 453j],
|
||||
[-3.5 + 5.8j, 63 + 12.75j, -5 + 6.4j]], dtype=ms.complex64)
|
||||
net = Net()
|
||||
output1, output2 = net(x)
|
||||
expect_output1 = np.array(0.749919+0.66153j, dtype=np.complex)
|
||||
expect_output2 = np.array(12.0614+0j, dtype=np.complex)
|
||||
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||
assert np.allclose(output2.asnumpy(), expect_output2)
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, diagonal=0):
|
||||
return ops.tril(x, diagonal)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_tril(mode):
|
||||
"""
|
||||
Feature: tril
|
||||
Description: Verify the result of tril
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor([[-1.8297, -0.8474, 1.0292], [-1.2167, 0.5574, -0.6753], [-0.6702, 0.2276, 1.2421]])
|
||||
net = Net()
|
||||
output = net(x)
|
||||
expect_output = np.array([[-1.8297, 0., 0.], [-1.2167, 0.5574, 0.], [-0.6702, 0.2276, 1.2421]], dtype=np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x.slogdet()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_slogdet(mode):
|
||||
"""
|
||||
Feature: slogdet
|
||||
Description: Verify the result of slogdet
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor([[-1.8297, -0.8474, 1.0292], [-1.2167, 0.5574, -0.6753], [-0.6702, 0.2276, 1.2421]])
|
||||
net = Net()
|
||||
output1, output2 = net(x)
|
||||
expect_output1 = np.array(-1, dtype=np.float32)
|
||||
expect_output2 = np.array(1.13549, dtype=np.float32)
|
||||
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||
assert np.allclose(output2.asnumpy(), expect_output2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_slogdet_complex(mode):
|
||||
"""
|
||||
Feature: slogdet
|
||||
Description: Verify the result of slogdet
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor([[-1.5 + 7.8j, 3 + 5.75j, 2 + 2.4j],
|
||||
[-6.4 + 485.4j, 45 + 3.14j, 45 + 453j],
|
||||
[-3.5 + 5.8j, 63 + 12.75j, -5 + 6.4j]], dtype=ms.complex64)
|
||||
net = Net()
|
||||
output1, output2 = net(x)
|
||||
expect_output1 = np.array(0.749919+0.66153j, dtype=np.complex)
|
||||
expect_output2 = np.array(12.0614+0j, dtype=np.complex)
|
||||
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||
assert np.allclose(output2.asnumpy(), expect_output2)
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, diagonal=0):
|
||||
return x.tril(diagonal)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_tril(mode):
|
||||
"""
|
||||
Feature: tril
|
||||
Description: Verify the result of tril
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor([[-1.8297, -0.8474, 1.0292], [-1.2167, 0.5574, -0.6753], [-0.6702, 0.2276, 1.2421]])
|
||||
net = Net()
|
||||
output = net(x)
|
||||
expect_output = np.array([[-1.8297, 0., 0.], [-1.2167, 0.5574, 0.], [-0.6702, 0.2276, 1.2421]], dtype=np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
Loading…
Reference in New Issue