add det method and st test cases.

This commit is contained in:
hezhenhao1 2022-01-05 11:00:24 +08:00
parent eaacc10cab
commit 43ffa30789
2 changed files with 143 additions and 0 deletions

View File

@ -649,3 +649,74 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
rhs_vector = m_lu.ndim == b.ndim + 1 rhs_vector = m_lu.ndim == b.ndim + 1
x = lu_solve_core(m_lu, permutation, b, trans) x = lu_solve_core(m_lu, permutation, b, trans)
return x[..., 0] if rhs_vector else x return x[..., 0] if rhs_vector else x
def _det_2x2(a):
return (a[..., 0, 0] * a[..., 1, 1] -
a[..., 0, 1] * a[..., 1, 0])
def _det_3x3(a):
return (a[..., 0, 0] * a[..., 1, 1] * a[..., 2, 2] +
a[..., 0, 1] * a[..., 1, 2] * a[..., 2, 0] +
a[..., 0, 2] * a[..., 1, 0] * a[..., 2, 1] -
a[..., 0, 2] * a[..., 1, 1] * a[..., 2, 0] -
a[..., 0, 0] * a[..., 1, 2] * a[..., 2, 1] -
a[..., 0, 1] * a[..., 1, 0] * a[..., 2, 2])
def det(a, overwrite_a=False, check_finite=True):
"""
Compute the determinant of a matrix
The determinant of a square matrix is a value derived arithmetically
from the coefficients of the matrix.
The determinant for a 3x3 matrix, for example, is computed as follows::
a b c
d e f = A
g h i
det(A) = a*e*i + b*f*g + c*d*h - c*e*g - b*d*i - a*f*h
Args:
a (Tensor): A square matrix to compute. Note that if the input tensor is not a `float`,
then it will be casted to :class:`mstype.float32`.
overwrite_a (bool, optional): Allow overwriting data in a (may enhance performance).
check_finite (bool, optional): Whether to check that the input matrix contains
only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Raises:
ValueError: If :math:`a` is not square.
Returns:
Tensor, Determinant of `a`.
Examples:
>>> import numpy as onp
>>> from mindspore.common import Tensor
>>> from mindspore.scipy.linalg import det
>>> a = Tensor(onp.array([[0, 2, 3], [4, 5, 6], [7, 8, 9]])).astype(onp.float64)
>>> det(a)
3.0
"""
# special case
if a.ndim >= 2 and a.shape[-1] == 2 and a.shape[-2] == 2:
return _det_2x2(a)
if a.ndim >= 2 and a.shape[-1] == 3 and a.shape[-2] == 3:
return _det_3x3(a)
if a.ndim < 2 or a.shape[-1] != a.shape[-2]:
_raise_value_error("Arguments to det must be [..., n, n], but got shape {}.".format(a.shape))
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float32)
lu_matrix, pivot = lu_factor(a)
diag = lu_matrix.diagonal(axis1=-2, axis2=-1)
pivot_not_equal = (pivot != mnp.arange(a.shape[-1])).astype(mstype.int64)
pivot_sign = mnp.count_nonzero(pivot_not_equal, axis=-1)
sign = -2. * (pivot_sign % 2) + 1.
return sign * P.ReduceProd(keep_dims=False)(diag, -1)

View File

@ -23,6 +23,7 @@ import mindspore.nn as nn
import mindspore.scipy as msp import mindspore.scipy as msp
from mindspore import context, Tensor from mindspore import context, Tensor
import mindspore.numpy as mnp import mindspore.numpy as mnp
from mindspore.scipy.linalg import det
from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix, \ from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix, \
create_random_rank_matrix create_random_rank_matrix
@ -328,6 +329,50 @@ def test_lu_solve(n: int, dtype):
assert onp.allclose(msp_x.asnumpy(), osp_x, rtol=rtol, atol=atol) assert onp.allclose(msp_x.asnumpy(), osp_x, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(3, 3), (5, 5), (10, 10), (20, 20)])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_det(shape, dtype):
"""
Feature: ALL To ALL
Description: test cases for det
Expectation: the result match to scipy
"""
a = onp.random.random(shape).astype(dtype)
sp_det = osp.linalg.det(a)
tensor_a = Tensor(a)
ms_det = msp.linalg.det(tensor_a)
rtol = 1.e-5
atol = 1.e-5
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(2, 3, 3), (2, 3, 5, 5)])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_batch_det(shape, dtype):
"""
Feature: ALL To ALL
Description: test batch cases for det
Expectation: the result match to scipy
"""
a = onp.random.random(shape).astype(dtype)
tensor_a = Tensor(a)
ms_det = msp.linalg.det(tensor_a)
sp_det = onp.empty(shape=ms_det.shape, dtype=dtype)
for index, _ in onp.ndenumerate(sp_det):
sp_det[index] = osp.linalg.det(a[index])
rtol = 1.e-5
atol = 1.e-5
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@ -351,3 +396,30 @@ def test_block_diag_graph(args):
scipy_res = osp.linalg.block_diag(*args) scipy_res = osp.linalg.block_diag(*args)
match_array(ms_res.asnumpy(), scipy_res) match_array(ms_res.asnumpy(), scipy_res)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(3, 3), (5, 5), (10, 10), (20, 20)])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_det_graph(shape, dtype):
"""
Feature: ALL To ALL
Description: test cases for det in graph mode
Expectation: the result match to scipy
"""
context.set_context(mode=context.GRAPH_MODE)
class TestNet(nn.Cell):
def construct(self, a):
return det(a)
a = onp.random.random(shape).astype(dtype)
sp_det = osp.linalg.det(a)
tensor_a = Tensor(a)
ms_det = TestNet()(tensor_a)
rtol = 1.e-5
atol = 1.e-5
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)